X-Git-Url: https://git.lukelau.me/?a=blobdiff_plain;f=src%2FLanguage%2FLSP%2FTest%2FSession.hs;h=f3d6f8c28e83c7250bd093bde204e79964e08204;hb=6042c31e8b18eefb81b98a8ebb3e1e6f4a004907;hp=6a6cf15a5a484d90a10a4da50217ab277b5d6c5f;hpb=1fb2c02419384b450fd43ae281ef410cb7bfb2cf;p=lsp-test.git diff --git a/src/Language/LSP/Test/Session.hs b/src/Language/LSP/Test/Session.hs index 6a6cf15..f3d6f8c 100644 --- a/src/Language/LSP/Test/Session.hs +++ b/src/Language/LSP/Test/Session.hs @@ -1,4 +1,5 @@ {-# LANGUAGE CPP #-} +{-# LANGUAGE BangPatterns #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} @@ -6,6 +7,7 @@ {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE RankNTypes #-} +{-# LANGUAGE TypeInType #-} module Language.LSP.Test.Session ( Session(..) @@ -28,6 +30,7 @@ module Language.LSP.Test.Session , bumpTimeoutId , logMsg , LogMsgType(..) + , documentChangeUri ) where @@ -54,7 +57,8 @@ import Data.Conduit.Parser as Parser import Data.Default import Data.Foldable import Data.List -import qualified Data.Map as Map +import qualified Data.Map.Strict as Map +import qualified Data.Set as Set import qualified Data.Text as T import qualified Data.Text.IO as T import qualified Data.HashMap.Strict as HashMap @@ -76,6 +80,7 @@ import System.Process (ProcessHandle()) import System.Process (waitForProcess) #endif import System.Timeout +import Data.IORef -- | A session representing one instance of launching and connecting to a server. -- @@ -132,9 +137,9 @@ data SessionContext = SessionContext , rootDir :: FilePath , messageChan :: Chan SessionMessage -- ^ Where all messages come through -- Keep curTimeoutId in SessionContext, as its tied to messageChan - , curTimeoutId :: MVar Int -- ^ The current timeout we are waiting on + , curTimeoutId :: IORef Int -- ^ The current timeout we are waiting on , requestMap :: MVar RequestMap - , initRsp :: MVar InitializeResponse + , initRsp :: MVar (ResponseMessage Initialize) , config :: SessionConfig , sessionCapabilities :: ClientCapabilities } @@ -151,7 +156,7 @@ instance Monad m => HasReader r (ConduitM a b (StateT s (ReaderT r m))) where ask = lift $ lift Reader.ask getCurTimeoutId :: (HasReader SessionContext m, MonadIO m) => m Int -getCurTimeoutId = asks curTimeoutId >>= liftIO . readMVar +getCurTimeoutId = asks curTimeoutId >>= liftIO . readIORef -- Pass this the timeoutid you *were* waiting on bumpTimeoutId :: (HasReader SessionContext m, MonadIO m) => Int -> m () @@ -159,20 +164,21 @@ bumpTimeoutId prev = do v <- asks curTimeoutId -- when updating the curtimeoutid, account for the fact that something else -- might have bumped the timeoutid in the meantime - liftIO $ modifyMVar_ v (\x -> pure (max x (prev + 1))) + liftIO $ atomicModifyIORef' v (\x -> (max x (prev + 1), ())) data SessionState = SessionState { - curReqId :: Int - , vfs :: VFS - , curDiagnostics :: Map.Map NormalizedUri [Diagnostic] - , overridingTimeout :: Bool + curReqId :: !Int + , vfs :: !VFS + , curDiagnostics :: !(Map.Map NormalizedUri [Diagnostic]) + , overridingTimeout :: !Bool -- ^ The last received message from the server. -- Used for providing exception information - , lastReceivedMessage :: Maybe FromServerMessage - , curDynCaps :: Map.Map T.Text SomeRegistration + , lastReceivedMessage :: !(Maybe FromServerMessage) + , curDynCaps :: !(Map.Map T.Text SomeRegistration) -- ^ The capabilities that the server has dynamically registered with us so -- far + , curProgressSessions :: !(Set.Set ProgressToken) } class Monad m => HasState s m where @@ -257,13 +263,13 @@ runSession' serverIn serverOut mServerProc serverHandler config caps rootDir exi reqMap <- newMVar newRequestMap messageChan <- newChan - timeoutIdVar <- newMVar 0 + timeoutIdVar <- newIORef 0 initRsp <- newEmptyMVar mainThreadId <- myThreadId let context = SessionContext serverIn absRootDir messageChan timeoutIdVar reqMap initRsp config caps - initState vfs = SessionState 0 vfs mempty False Nothing mempty + initState vfs = SessionState 0 vfs mempty False Nothing mempty mempty runSession' ses = initVFS $ \vfs -> runSessionMonad context (initState vfs) ses errorHandler = throwTo mainThreadId :: SessionException -> IO () @@ -293,10 +299,33 @@ runSession' serverIn serverOut mServerProc serverHandler config caps rootDir exi updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) () updateStateC = awaitForever $ \msg -> do updateState msg + respond msg yield msg + where + respond :: (MonadIO m, HasReader SessionContext m) => FromServerMessage -> m () + respond (FromServerMess SWindowWorkDoneProgressCreate req) = + sendMessage $ ResponseMessage "2.0" (Just $ req ^. LSP.id) (Right ()) + respond (FromServerMess SWorkspaceApplyEdit r) = do + sendMessage $ ResponseMessage "2.0" (Just $ r ^. LSP.id) (Right $ ApplyWorkspaceEditResponseBody True Nothing) + respond _ = pure () + + +-- extract Uri out from DocumentChange +-- didn't put this in `lsp-types` because TH was getting in the way +documentChangeUri :: DocumentChange -> Uri +documentChangeUri (InL x) = x ^. textDocument . uri +documentChangeUri (InR (InL x)) = x ^. uri +documentChangeUri (InR (InR (InL x))) = x ^. oldUri +documentChangeUri (InR (InR (InR x))) = x ^. uri updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m) => FromServerMessage -> m () +updateState (FromServerMess SProgress req) = case req ^. params . value of + Begin _ -> + modify $ \s -> s { curProgressSessions = Set.insert (req ^. params . token) $ curProgressSessions s } + End _ -> + modify $ \s -> s { curProgressSessions = Set.delete (req ^. params . token) $ curProgressSessions s } + _ -> pure () -- Keep track of dynamic capability registration updateState (FromServerMess SClientRegisterCapability req) = do @@ -322,8 +351,8 @@ updateState (FromServerMess SWorkspaceApplyEdit r) = do -- First, prefer the versioned documentChanges field allChangeParams <- case r ^. params . edit . documentChanges of Just (List cs) -> do - mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs - return $ map getParams cs + mapM_ (checkIfNeedsOpened . documentChangeUri) cs + return $ mapMaybe getParamsFromDocumentChange cs -- Then fall back to the changes field Nothing -> case r ^. params . edit . changes of Just cs -> do @@ -370,10 +399,16 @@ updateState (FromServerMess SWorkspaceApplyEdit r) = do let (newVFS,_) = openVFS (vfs s) msg return $ s { vfs = newVFS } - getParams (TextDocumentEdit docId (List edits)) = + getParamsFromTextDocumentEdit :: TextDocumentEdit -> DidChangeTextDocumentParams + getParamsFromTextDocumentEdit (TextDocumentEdit docId (List edits)) = let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits in DidChangeTextDocumentParams docId (List changeEvents) + getParamsFromDocumentChange :: DocumentChange -> Maybe DidChangeTextDocumentParams + getParamsFromDocumentChange (InL textDocumentEdit) = Just $ getParamsFromTextDocumentEdit textDocumentEdit + getParamsFromDocumentChange _ = Nothing + + -- For a uri returns an infinite list of versions [n,n+1,n+2,...] -- where n is the current version textDocumentVersions uri = do @@ -386,8 +421,8 @@ updateState (FromServerMess SWorkspaceApplyEdit r) = do vers <- textDocumentVersions uri pure $ map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip vers edits - getChangeParams uri (List edits) = - map <$> pure getParams <*> textDocumentEdits uri (reverse edits) + getChangeParams uri (List edits) = do + map <$> pure getParamsFromTextDocumentEdit <*> textDocumentEdits uri (reverse edits) mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params)) @@ -408,10 +443,11 @@ withTimeout duration f = do chan <- asks messageChan timeoutId <- getCurTimeoutId modify $ \s -> s { overridingTimeout = True } - liftIO $ forkIO $ do + tid <- liftIO $ forkIO $ do threadDelay (duration * 1000000) writeChan chan (TimeoutMessage timeoutId) res <- f + liftIO $ killThread tid bumpTimeoutId timeoutId modify $ \s -> s { overridingTimeout = False } return res