respond to progress create and apply edit
[lsp-test.git] / src / Language / LSP / Test / Session.hs
1 {-# LANGUAGE CPP               #-}
2 {-# LANGUAGE GADTs             #-}
3 {-# LANGUAGE OverloadedStrings #-}
4 {-# LANGUAGE GeneralizedNewtypeDeriving #-}
5 {-# LANGUAGE FlexibleInstances #-}
6 {-# LANGUAGE MultiParamTypeClasses #-}
7 {-# LANGUAGE FlexibleContexts #-}
8 {-# LANGUAGE RankNTypes #-}
9 {-# LANGUAGE TypeInType #-}
10
11 module Language.LSP.Test.Session
12   ( Session(..)
13   , SessionConfig(..)
14   , defaultConfig
15   , SessionMessage(..)
16   , SessionContext(..)
17   , SessionState(..)
18   , runSession'
19   , get
20   , put
21   , modify
22   , modifyM
23   , ask
24   , asks
25   , sendMessage
26   , updateState
27   , withTimeout
28   , getCurTimeoutId
29   , bumpTimeoutId
30   , logMsg
31   , LogMsgType(..)
32   , documentChangeUri
33   )
34
35 where
36
37 import Control.Applicative
38 import Control.Concurrent hiding (yield)
39 import Control.Exception
40 import Control.Lens hiding (List)
41 import Control.Monad
42 import Control.Monad.IO.Class
43 import Control.Monad.Except
44 #if __GLASGOW_HASKELL__ == 806
45 import Control.Monad.Fail
46 #endif
47 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
48 import qualified Control.Monad.Trans.Reader as Reader (ask)
49 import Control.Monad.Trans.State (StateT, runStateT)
50 import qualified Control.Monad.Trans.State as State
51 import qualified Data.ByteString.Lazy.Char8 as B
52 import Data.Aeson
53 import Data.Aeson.Encode.Pretty
54 import Data.Conduit as Conduit
55 import Data.Conduit.Parser as Parser
56 import Data.Default
57 import Data.Foldable
58 import Data.List
59 import qualified Data.Map as Map
60 import qualified Data.Set as Set
61 import qualified Data.Text as T
62 import qualified Data.Text.IO as T
63 import qualified Data.HashMap.Strict as HashMap
64 import Data.Maybe
65 import Data.Function
66 import Language.LSP.Types.Capabilities
67 import Language.LSP.Types
68 import Language.LSP.Types.Lens
69 import qualified Language.LSP.Types.Lens as LSP
70 import Language.LSP.VFS
71 import Language.LSP.Test.Compat
72 import Language.LSP.Test.Decoding
73 import Language.LSP.Test.Exceptions
74 import System.Console.ANSI
75 import System.Directory
76 import System.IO
77 import System.Process (ProcessHandle())
78 #ifndef mingw32_HOST_OS
79 import System.Process (waitForProcess)
80 #endif
81 import System.Timeout
82
83 -- | A session representing one instance of launching and connecting to a server.
84 --
85 -- You can send and receive messages to the server within 'Session' via
86 -- 'Language.LSP.Test.message',
87 -- 'Language.LSP.Test.sendRequest' and
88 -- 'Language.LSP.Test.sendNotification'.
89
90 newtype Session a = Session (ConduitParser FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) a)
91   deriving (Functor, Applicative, Monad, MonadIO, Alternative)
92
93 #if __GLASGOW_HASKELL__ >= 806
94 instance MonadFail Session where
95   fail s = do
96     lastMsg <- fromJust . lastReceivedMessage <$> get
97     liftIO $ throw (UnexpectedMessage s lastMsg)
98 #endif
99
100 -- | Stuff you can configure for a 'Session'.
101 data SessionConfig = SessionConfig
102   { messageTimeout :: Int  -- ^ Maximum time to wait for a message in seconds, defaults to 60.
103   , logStdErr      :: Bool
104   -- ^ Redirect the server's stderr to this stdout, defaults to False.
105   -- Can be overriden with @LSP_TEST_LOG_STDERR@.
106   , logMessages    :: Bool
107   -- ^ Trace the messages sent and received to stdout, defaults to False.
108   -- Can be overriden with the environment variable @LSP_TEST_LOG_MESSAGES@.
109   , logColor       :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
110   , lspConfig      :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
111   , ignoreLogNotifications :: Bool
112   -- ^ Whether or not to ignore 'Language.LSP.Types.ShowMessageNotification' and
113   -- 'Language.LSP.Types.LogMessageNotification', defaults to False.
114   --
115   -- @since 0.9.0.0
116   , initialWorkspaceFolders :: Maybe [WorkspaceFolder]
117   -- ^ The initial workspace folders to send in the @initialize@ request.
118   -- Defaults to Nothing.
119   }
120
121 -- | The configuration used in 'Language.LSP.Test.runSession'.
122 defaultConfig :: SessionConfig
123 defaultConfig = SessionConfig 60 False False True Nothing False Nothing
124
125 instance Default SessionConfig where
126   def = defaultConfig
127
128 data SessionMessage = ServerMessage FromServerMessage
129                     | TimeoutMessage Int
130   deriving Show
131
132 data SessionContext = SessionContext
133   {
134     serverIn :: Handle
135   , rootDir :: FilePath
136   , messageChan :: Chan SessionMessage -- ^ Where all messages come through
137   -- Keep curTimeoutId in SessionContext, as its tied to messageChan
138   , curTimeoutId :: MVar Int -- ^ The current timeout we are waiting on
139   , requestMap :: MVar RequestMap
140   , initRsp :: MVar (ResponseMessage Initialize)
141   , config :: SessionConfig
142   , sessionCapabilities :: ClientCapabilities
143   }
144
145 class Monad m => HasReader r m where
146   ask :: m r
147   asks :: (r -> b) -> m b
148   asks f = f <$> ask
149
150 instance HasReader SessionContext Session where
151   ask  = Session (lift $ lift Reader.ask)
152
153 instance Monad m => HasReader r (ConduitM a b (StateT s (ReaderT r m))) where
154   ask = lift $ lift Reader.ask
155
156 getCurTimeoutId :: (HasReader SessionContext m, MonadIO m) => m Int
157 getCurTimeoutId = asks curTimeoutId >>= liftIO . readMVar
158
159 -- Pass this the timeoutid you *were* waiting on
160 bumpTimeoutId :: (HasReader SessionContext m, MonadIO m) => Int -> m ()
161 bumpTimeoutId prev = do
162   v <- asks curTimeoutId
163   -- when updating the curtimeoutid, account for the fact that something else
164   -- might have bumped the timeoutid in the meantime
165   liftIO $ modifyMVar_ v (\x -> pure (max x (prev + 1)))
166
167 data SessionState = SessionState
168   {
169     curReqId :: Int
170   , vfs :: VFS
171   , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
172   , overridingTimeout :: Bool
173   -- ^ The last received message from the server.
174   -- Used for providing exception information
175   , lastReceivedMessage :: Maybe FromServerMessage
176   , curDynCaps :: Map.Map T.Text SomeRegistration
177   -- ^ The capabilities that the server has dynamically registered with us so
178   -- far
179   , curProgressSessions :: Set.Set ProgressToken
180   }
181
182 class Monad m => HasState s m where
183   get :: m s
184
185   put :: s -> m ()
186
187   modify :: (s -> s) -> m ()
188   modify f = get >>= put . f
189
190   modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
191   modifyM f = get >>= f >>= put
192
193 instance HasState SessionState Session where
194   get = Session (lift State.get)
195   put = Session . lift . State.put
196
197 instance Monad m => HasState s (StateT s m) where
198   get = State.get
199   put = State.put
200
201 instance (Monad m, (HasState s m)) => HasState s (ConduitM a b m)
202  where
203   get = lift get
204   put = lift . put
205
206 instance (Monad m, (HasState s m)) => HasState s (ConduitParser a m)
207  where
208   get = lift get
209   put = lift . put
210
211 runSessionMonad :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
212 runSessionMonad context state (Session session) = runReaderT (runStateT conduit state) context
213   where
214     conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
215
216     handler (Unexpected "ConduitParser.empty") = do
217       lastMsg <- fromJust . lastReceivedMessage <$> get
218       name <- getParserName
219       liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
220
221     handler e = throw e
222
223     chanSource = do
224       msg <- liftIO $ readChan (messageChan context)
225       unless (ignoreLogNotifications (config context) && isLogNotification msg) $
226         yield msg
227       chanSource
228
229     isLogNotification (ServerMessage (FromServerMess SWindowShowMessage _)) = True
230     isLogNotification (ServerMessage (FromServerMess SWindowLogMessage _)) = True
231     isLogNotification _ = False
232
233     watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
234     watchdog = Conduit.awaitForever $ \msg -> do
235       curId <- getCurTimeoutId
236       case msg of
237         ServerMessage sMsg -> yield sMsg
238         TimeoutMessage tId -> when (curId == tId) $ lastReceivedMessage <$> get >>= throw . Timeout
239
240 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
241 -- It also does not automatically send initialize and exit messages.
242 runSession' :: Handle -- ^ Server in
243             -> Handle -- ^ Server out
244             -> Maybe ProcessHandle -- ^ Server process
245             -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
246             -> SessionConfig
247             -> ClientCapabilities
248             -> FilePath -- ^ Root directory
249             -> Session () -- ^ To exit the Server properly
250             -> Session a
251             -> IO a
252 runSession' serverIn serverOut mServerProc serverHandler config caps rootDir exitServer session = do
253   absRootDir <- canonicalizePath rootDir
254
255   hSetBuffering serverIn  NoBuffering
256   hSetBuffering serverOut NoBuffering
257   -- This is required to make sure that we don’t get any
258   -- newline conversion or weird encoding issues.
259   hSetBinaryMode serverIn True
260   hSetBinaryMode serverOut True
261
262   reqMap <- newMVar newRequestMap
263   messageChan <- newChan
264   timeoutIdVar <- newMVar 0
265   initRsp <- newEmptyMVar
266
267   mainThreadId <- myThreadId
268
269   let context = SessionContext serverIn absRootDir messageChan timeoutIdVar reqMap initRsp config caps
270       initState vfs = SessionState 0 vfs mempty False Nothing mempty mempty
271       runSession' ses = initVFS $ \vfs -> runSessionMonad context (initState vfs) ses
272
273       errorHandler = throwTo mainThreadId :: SessionException -> IO ()
274       serverListenerLauncher =
275         forkIO $ catch (serverHandler serverOut context) errorHandler
276       msgTimeoutMs = messageTimeout config * 10^6
277       serverAndListenerFinalizer tid = do
278         let cleanup
279               | Just sp <- mServerProc = do
280                   -- Give the server some time to exit cleanly
281                   -- It makes the server hangs in windows so we have to avoid it
282 #ifndef mingw32_HOST_OS
283                   timeout msgTimeoutMs (waitForProcess sp)
284 #endif
285                   cleanupProcess (Just serverIn, Just serverOut, Nothing, sp)
286               | otherwise = pure ()
287         finally (timeout msgTimeoutMs (runSession' exitServer))
288                 -- Make sure to kill the listener first, before closing
289                 -- handles etc via cleanupProcess
290                 (killThread tid >> cleanup)
291
292   (result, _) <- bracket serverListenerLauncher
293                          serverAndListenerFinalizer
294                          (const $ initVFS $ \vfs -> runSessionMonad context (initState vfs) session)
295   return result
296
297 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
298 updateStateC = awaitForever $ \msg -> do
299   updateState msg
300   yield msg
301
302 -- extract Uri out from DocumentChange
303 -- didn't put this in `lsp-types` because TH was getting in the way
304 documentChangeUri :: DocumentChange -> Uri
305 documentChangeUri (InL x) = x ^. textDocument . uri
306 documentChangeUri (InR (InL x)) = x ^. uri
307 documentChangeUri (InR (InR (InL x))) = x ^. oldUri
308 documentChangeUri (InR (InR (InR x))) = x ^. uri
309
310 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m)
311             => FromServerMessage -> m ()
312 updateState (FromServerMess SWindowWorkDoneProgressCreate req) =
313   sendMessage $ ResponseMessage "2.0" (Just $ req ^. LSP.id) (Right ())
314 updateState (FromServerMess SProgress req) = case req ^. params . value of
315   Begin _ ->
316     modify $ \s -> s { curProgressSessions = Set.insert (req ^. params . token) $ curProgressSessions s }
317   End _ ->
318     modify $ \s -> s { curProgressSessions = Set.delete (req ^. params . token) $ curProgressSessions s }
319   _ -> pure ()
320
321 -- Keep track of dynamic capability registration
322 updateState (FromServerMess SClientRegisterCapability req) = do
323   let List newRegs = (\sr@(SomeRegistration r) -> (r ^. LSP.id, sr)) <$> req ^. params . registrations
324   modify $ \s ->
325     s { curDynCaps = Map.union (Map.fromList newRegs) (curDynCaps s) }
326
327 updateState (FromServerMess SClientUnregisterCapability req) = do
328   let List unRegs = (^. LSP.id) <$> req ^. params . unregisterations
329   modify $ \s ->
330     let newCurDynCaps = foldr' Map.delete (curDynCaps s) unRegs
331     in s { curDynCaps = newCurDynCaps }
332
333 updateState (FromServerMess STextDocumentPublishDiagnostics n) = do
334   let List diags = n ^. params . diagnostics
335       doc = n ^. params . uri
336   modify $ \s ->
337     let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s)
338       in s { curDiagnostics = newDiags }
339
340 updateState (FromServerMess SWorkspaceApplyEdit r) = do
341
342   -- First, prefer the versioned documentChanges field
343   allChangeParams <- case r ^. params . edit . documentChanges of
344     Just (List cs) -> do
345       mapM_ (checkIfNeedsOpened . documentChangeUri) cs
346       return $ mapMaybe getParamsFromDocumentChange cs
347     -- Then fall back to the changes field
348     Nothing -> case r ^. params . edit . changes of
349       Just cs -> do
350         mapM_ checkIfNeedsOpened (HashMap.keys cs)
351         concat <$> mapM (uncurry getChangeParams) (HashMap.toList cs)
352       Nothing ->
353         error "WorkspaceEdit contains neither documentChanges nor changes!"
354
355   modifyM $ \s -> do
356     newVFS <- liftIO $ changeFromServerVFS (vfs s) r
357     return $ s { vfs = newVFS }
358
359   let groupedParams = groupBy (\a b -> a ^. textDocument == b ^. textDocument) allChangeParams
360       mergedParams = map mergeParams groupedParams
361
362   -- TODO: Don't do this when replaying a session
363   forM_ mergedParams (sendMessage . NotificationMessage "2.0" STextDocumentDidChange)
364
365   sendMessage $ ResponseMessage "2.0" (Just $ r ^. LSP.id) (Right $ ApplyWorkspaceEditResponseBody True Nothing)
366
367   -- Update VFS to new document versions
368   let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
369       latestVersions = map ((^. textDocument) . last) sortedVersions
370       bumpedVersions = map (version . _Just +~ 1) latestVersions
371
372   forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
373     modify $ \s ->
374       let oldVFS = vfs s
375           update (VirtualFile oldV file_ver t) = VirtualFile (fromMaybe oldV v) (file_ver + 1) t
376           newVFS = updateVFS (Map.adjust update (toNormalizedUri uri)) oldVFS
377       in s { vfs = newVFS }
378
379   where checkIfNeedsOpened uri = do
380           oldVFS <- vfs <$> get
381           ctx <- ask
382
383           -- if its not open, open it
384           unless (toNormalizedUri uri `Map.member` vfsMap oldVFS) $ do
385             let fp = fromJust $ uriToFilePath uri
386             contents <- liftIO $ T.readFile fp
387             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
388                 msg = NotificationMessage "2.0" STextDocumentDidOpen (DidOpenTextDocumentParams item)
389             liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
390
391             modifyM $ \s -> do
392               let (newVFS,_) = openVFS (vfs s) msg
393               return $ s { vfs = newVFS }
394
395         getParamsFromTextDocumentEdit :: TextDocumentEdit -> DidChangeTextDocumentParams
396         getParamsFromTextDocumentEdit (TextDocumentEdit docId (List edits)) = 
397           let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
398             in DidChangeTextDocumentParams docId (List changeEvents)
399
400         getParamsFromDocumentChange :: DocumentChange -> Maybe DidChangeTextDocumentParams
401         getParamsFromDocumentChange (InL textDocumentEdit) = Just $ getParamsFromTextDocumentEdit textDocumentEdit
402         getParamsFromDocumentChange _ = Nothing
403
404
405         -- For a uri returns an infinite list of versions [n,n+1,n+2,...]
406         -- where n is the current version
407         textDocumentVersions uri = do
408           m <- vfsMap . vfs <$> get
409           let curVer = fromMaybe 0 $
410                 _lsp_version <$> m Map.!? (toNormalizedUri uri)
411           pure $ map (VersionedTextDocumentIdentifier uri . Just) [curVer + 1..]
412
413         textDocumentEdits uri edits = do
414           vers <- textDocumentVersions uri
415           pure $ map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip vers edits
416
417         getChangeParams uri (List edits) = do 
418           map <$> pure getParamsFromTextDocumentEdit <*> textDocumentEdits uri (reverse edits)
419
420         mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
421         mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
422                               in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
423 updateState _ = return ()
424
425 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
426 sendMessage msg = do
427   h <- serverIn <$> ask
428   logMsg LogClient msg
429   liftIO $ B.hPut h (addHeader $ encode msg)
430
431 -- | Execute a block f that will throw a 'Language.LSP.Test.Exception.Timeout' exception
432 -- after duration seconds. This will override the global timeout
433 -- for waiting for messages to arrive defined in 'SessionConfig'.
434 withTimeout :: Int -> Session a -> Session a
435 withTimeout duration f = do
436   chan <- asks messageChan
437   timeoutId <- getCurTimeoutId
438   modify $ \s -> s { overridingTimeout = True }
439   liftIO $ forkIO $ do
440     threadDelay (duration * 1000000)
441     writeChan chan (TimeoutMessage timeoutId)
442   res <- f
443   bumpTimeoutId timeoutId
444   modify $ \s -> s { overridingTimeout = False }
445   return res
446
447 data LogMsgType = LogServer | LogClient
448   deriving Eq
449
450 -- | Logs the message if the config specified it
451 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
452        => LogMsgType -> a -> m ()
453 logMsg t msg = do
454   shouldLog <- asks $ logMessages . config
455   shouldColor <- asks $ logColor . config
456   liftIO $ when shouldLog $ do
457     when shouldColor $ setSGR [SetColor Foreground Dull color]
458     putStrLn $ arrow ++ showPretty msg
459     when shouldColor $ setSGR [Reset]
460
461   where arrow
462           | t == LogServer  = "<-- "
463           | otherwise       = "--> "
464         color
465           | t == LogServer  = Magenta
466           | otherwise       = Cyan
467
468         showPretty = B.unpack . encodePretty