track progress sessions
[lsp-test.git] / src / Language / LSP / Test / Session.hs
1 {-# LANGUAGE CPP               #-}
2 {-# LANGUAGE GADTs             #-}
3 {-# LANGUAGE OverloadedStrings #-}
4 {-# LANGUAGE GeneralizedNewtypeDeriving #-}
5 {-# LANGUAGE FlexibleInstances #-}
6 {-# LANGUAGE MultiParamTypeClasses #-}
7 {-# LANGUAGE FlexibleContexts #-}
8 {-# LANGUAGE RankNTypes #-}
9 {-# LANGUAGE TypeInType #-}
10
11 module Language.LSP.Test.Session
12   ( Session(..)
13   , SessionConfig(..)
14   , defaultConfig
15   , SessionMessage(..)
16   , SessionContext(..)
17   , SessionState(..)
18   , runSession'
19   , get
20   , put
21   , modify
22   , modifyM
23   , ask
24   , asks
25   , sendMessage
26   , updateState
27   , withTimeout
28   , getCurTimeoutId
29   , bumpTimeoutId
30   , logMsg
31   , LogMsgType(..)
32   , documentChangeUri
33   )
34
35 where
36
37 import Control.Applicative
38 import Control.Concurrent hiding (yield)
39 import Control.Exception
40 import Control.Lens hiding (List)
41 import Control.Monad
42 import Control.Monad.IO.Class
43 import Control.Monad.Except
44 #if __GLASGOW_HASKELL__ == 806
45 import Control.Monad.Fail
46 #endif
47 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
48 import qualified Control.Monad.Trans.Reader as Reader (ask)
49 import Control.Monad.Trans.State (StateT, runStateT)
50 import qualified Control.Monad.Trans.State as State
51 import qualified Data.ByteString.Lazy.Char8 as B
52 import Data.Aeson
53 import Data.Aeson.Encode.Pretty
54 import Data.Conduit as Conduit
55 import Data.Conduit.Parser as Parser
56 import Data.Default
57 import Data.Foldable
58 import Data.List
59 import qualified Data.Map as Map
60 import qualified Data.Set as Set
61 import qualified Data.Text as T
62 import qualified Data.Text.IO as T
63 import qualified Data.HashMap.Strict as HashMap
64 import Data.Maybe
65 import Data.Function
66 import Language.LSP.Types.Capabilities
67 import Language.LSP.Types
68 import Language.LSP.Types.Lens
69 import qualified Language.LSP.Types.Lens as LSP
70 import Language.LSP.VFS
71 import Language.LSP.Test.Compat
72 import Language.LSP.Test.Decoding
73 import Language.LSP.Test.Exceptions
74 import System.Console.ANSI
75 import System.Directory
76 import System.IO
77 import System.Process (ProcessHandle())
78 #ifndef mingw32_HOST_OS
79 import System.Process (waitForProcess)
80 #endif
81 import System.Timeout
82
83 -- | A session representing one instance of launching and connecting to a server.
84 --
85 -- You can send and receive messages to the server within 'Session' via
86 -- 'Language.LSP.Test.message',
87 -- 'Language.LSP.Test.sendRequest' and
88 -- 'Language.LSP.Test.sendNotification'.
89
90 newtype Session a = Session (ConduitParser FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) a)
91   deriving (Functor, Applicative, Monad, MonadIO, Alternative)
92
93 #if __GLASGOW_HASKELL__ >= 806
94 instance MonadFail Session where
95   fail s = do
96     lastMsg <- fromJust . lastReceivedMessage <$> get
97     liftIO $ throw (UnexpectedMessage s lastMsg)
98 #endif
99
100 -- | Stuff you can configure for a 'Session'.
101 data SessionConfig = SessionConfig
102   { messageTimeout :: Int  -- ^ Maximum time to wait for a message in seconds, defaults to 60.
103   , logStdErr      :: Bool
104   -- ^ Redirect the server's stderr to this stdout, defaults to False.
105   -- Can be overriden with @LSP_TEST_LOG_STDERR@.
106   , logMessages    :: Bool
107   -- ^ Trace the messages sent and received to stdout, defaults to False.
108   -- Can be overriden with the environment variable @LSP_TEST_LOG_MESSAGES@.
109   , logColor       :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
110   , lspConfig      :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
111   , ignoreLogNotifications :: Bool
112   -- ^ Whether or not to ignore 'Language.LSP.Types.ShowMessageNotification' and
113   -- 'Language.LSP.Types.LogMessageNotification', defaults to False.
114   --
115   -- @since 0.9.0.0
116   , initialWorkspaceFolders :: Maybe [WorkspaceFolder]
117   -- ^ The initial workspace folders to send in the @initialize@ request.
118   -- Defaults to Nothing.
119   }
120
121 -- | The configuration used in 'Language.LSP.Test.runSession'.
122 defaultConfig :: SessionConfig
123 defaultConfig = SessionConfig 60 False False True Nothing False Nothing
124
125 instance Default SessionConfig where
126   def = defaultConfig
127
128 data SessionMessage = ServerMessage FromServerMessage
129                     | TimeoutMessage Int
130   deriving Show
131
132 data SessionContext = SessionContext
133   {
134     serverIn :: Handle
135   , rootDir :: FilePath
136   , messageChan :: Chan SessionMessage -- ^ Where all messages come through
137   -- Keep curTimeoutId in SessionContext, as its tied to messageChan
138   , curTimeoutId :: MVar Int -- ^ The current timeout we are waiting on
139   , requestMap :: MVar RequestMap
140   , initRsp :: MVar (ResponseMessage Initialize)
141   , config :: SessionConfig
142   , sessionCapabilities :: ClientCapabilities
143   }
144
145 class Monad m => HasReader r m where
146   ask :: m r
147   asks :: (r -> b) -> m b
148   asks f = f <$> ask
149
150 instance HasReader SessionContext Session where
151   ask  = Session (lift $ lift Reader.ask)
152
153 instance Monad m => HasReader r (ConduitM a b (StateT s (ReaderT r m))) where
154   ask = lift $ lift Reader.ask
155
156 getCurTimeoutId :: (HasReader SessionContext m, MonadIO m) => m Int
157 getCurTimeoutId = asks curTimeoutId >>= liftIO . readMVar
158
159 -- Pass this the timeoutid you *were* waiting on
160 bumpTimeoutId :: (HasReader SessionContext m, MonadIO m) => Int -> m ()
161 bumpTimeoutId prev = do
162   v <- asks curTimeoutId
163   -- when updating the curtimeoutid, account for the fact that something else
164   -- might have bumped the timeoutid in the meantime
165   liftIO $ modifyMVar_ v (\x -> pure (max x (prev + 1)))
166
167 data SessionState = SessionState
168   {
169     curReqId :: Int
170   , vfs :: VFS
171   , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
172   , overridingTimeout :: Bool
173   -- ^ The last received message from the server.
174   -- Used for providing exception information
175   , lastReceivedMessage :: Maybe FromServerMessage
176   , curDynCaps :: Map.Map T.Text SomeRegistration
177   -- ^ The capabilities that the server has dynamically registered with us so
178   -- far
179   , curProgressSessions :: Set.Set ProgressToken
180   }
181
182 class Monad m => HasState s m where
183   get :: m s
184
185   put :: s -> m ()
186
187   modify :: (s -> s) -> m ()
188   modify f = get >>= put . f
189
190   modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
191   modifyM f = get >>= f >>= put
192
193 instance HasState SessionState Session where
194   get = Session (lift State.get)
195   put = Session . lift . State.put
196
197 instance Monad m => HasState s (StateT s m) where
198   get = State.get
199   put = State.put
200
201 instance (Monad m, (HasState s m)) => HasState s (ConduitM a b m)
202  where
203   get = lift get
204   put = lift . put
205
206 instance (Monad m, (HasState s m)) => HasState s (ConduitParser a m)
207  where
208   get = lift get
209   put = lift . put
210
211 runSessionMonad :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
212 runSessionMonad context state (Session session) = runReaderT (runStateT conduit state) context
213   where
214     conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
215
216     handler (Unexpected "ConduitParser.empty") = do
217       lastMsg <- fromJust . lastReceivedMessage <$> get
218       name <- getParserName
219       liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
220
221     handler e = throw e
222
223     chanSource = do
224       msg <- liftIO $ readChan (messageChan context)
225       unless (ignoreLogNotifications (config context) && isLogNotification msg) $
226         yield msg
227       chanSource
228
229     isLogNotification (ServerMessage (FromServerMess SWindowShowMessage _)) = True
230     isLogNotification (ServerMessage (FromServerMess SWindowLogMessage _)) = True
231     isLogNotification _ = False
232
233     watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
234     watchdog = Conduit.awaitForever $ \msg -> do
235       curId <- getCurTimeoutId
236       case msg of
237         ServerMessage sMsg -> yield sMsg
238         TimeoutMessage tId -> when (curId == tId) $ lastReceivedMessage <$> get >>= throw . Timeout
239
240 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
241 -- It also does not automatically send initialize and exit messages.
242 runSession' :: Handle -- ^ Server in
243             -> Handle -- ^ Server out
244             -> Maybe ProcessHandle -- ^ Server process
245             -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
246             -> SessionConfig
247             -> ClientCapabilities
248             -> FilePath -- ^ Root directory
249             -> Session () -- ^ To exit the Server properly
250             -> Session a
251             -> IO a
252 runSession' serverIn serverOut mServerProc serverHandler config caps rootDir exitServer session = do
253   absRootDir <- canonicalizePath rootDir
254
255   hSetBuffering serverIn  NoBuffering
256   hSetBuffering serverOut NoBuffering
257   -- This is required to make sure that we don’t get any
258   -- newline conversion or weird encoding issues.
259   hSetBinaryMode serverIn True
260   hSetBinaryMode serverOut True
261
262   reqMap <- newMVar newRequestMap
263   messageChan <- newChan
264   timeoutIdVar <- newMVar 0
265   initRsp <- newEmptyMVar
266
267   mainThreadId <- myThreadId
268
269   let context = SessionContext serverIn absRootDir messageChan timeoutIdVar reqMap initRsp config caps
270       initState vfs = SessionState 0 vfs mempty False Nothing mempty mempty
271       runSession' ses = initVFS $ \vfs -> runSessionMonad context (initState vfs) ses
272
273       errorHandler = throwTo mainThreadId :: SessionException -> IO ()
274       serverListenerLauncher =
275         forkIO $ catch (serverHandler serverOut context) errorHandler
276       msgTimeoutMs = messageTimeout config * 10^6
277       serverAndListenerFinalizer tid = do
278         let cleanup
279               | Just sp <- mServerProc = do
280                   -- Give the server some time to exit cleanly
281                   -- It makes the server hangs in windows so we have to avoid it
282 #ifndef mingw32_HOST_OS
283                   timeout msgTimeoutMs (waitForProcess sp)
284 #endif
285                   cleanupProcess (Just serverIn, Just serverOut, Nothing, sp)
286               | otherwise = pure ()
287         finally (timeout msgTimeoutMs (runSession' exitServer))
288                 -- Make sure to kill the listener first, before closing
289                 -- handles etc via cleanupProcess
290                 (killThread tid >> cleanup)
291
292   (result, _) <- bracket serverListenerLauncher
293                          serverAndListenerFinalizer
294                          (const $ initVFS $ \vfs -> runSessionMonad context (initState vfs) session)
295   return result
296
297 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
298 updateStateC = awaitForever $ \msg -> do
299   updateState msg
300   yield msg
301
302 -- extract Uri out from DocumentChange
303 -- didn't put this in `lsp-types` because TH was getting in the way
304 documentChangeUri :: DocumentChange -> Uri
305 documentChangeUri (InL x) = x ^. textDocument . uri
306 documentChangeUri (InR (InL x)) = x ^. uri
307 documentChangeUri (InR (InR (InL x))) = x ^. oldUri
308 documentChangeUri (InR (InR (InR x))) = x ^. uri
309
310 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m)
311             => FromServerMessage -> m ()
312 updateState (FromServerMess SProgress req) = case req ^. params . value of
313   Begin _ ->
314     modify $ \s -> s { curProgressSessions = Set.insert (req ^. params . token) $ curProgressSessions s }
315   End _ ->
316     modify $ \s -> s { curProgressSessions = Set.delete (req ^. params . token) $ curProgressSessions s }
317   _ -> pure ()
318
319 -- Keep track of dynamic capability registration
320 updateState (FromServerMess SClientRegisterCapability req) = do
321   let List newRegs = (\sr@(SomeRegistration r) -> (r ^. LSP.id, sr)) <$> req ^. params . registrations
322   modify $ \s ->
323     s { curDynCaps = Map.union (Map.fromList newRegs) (curDynCaps s) }
324
325 updateState (FromServerMess SClientUnregisterCapability req) = do
326   let List unRegs = (^. LSP.id) <$> req ^. params . unregisterations
327   modify $ \s ->
328     let newCurDynCaps = foldr' Map.delete (curDynCaps s) unRegs
329     in s { curDynCaps = newCurDynCaps }
330
331 updateState (FromServerMess STextDocumentPublishDiagnostics n) = do
332   let List diags = n ^. params . diagnostics
333       doc = n ^. params . uri
334   modify $ \s ->
335     let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s)
336       in s { curDiagnostics = newDiags }
337
338 updateState (FromServerMess SWorkspaceApplyEdit r) = do
339
340   -- First, prefer the versioned documentChanges field
341   allChangeParams <- case r ^. params . edit . documentChanges of
342     Just (List cs) -> do
343       mapM_ (checkIfNeedsOpened . documentChangeUri) cs
344       return $ mapMaybe getParamsFromDocumentChange cs
345     -- Then fall back to the changes field
346     Nothing -> case r ^. params . edit . changes of
347       Just cs -> do
348         mapM_ checkIfNeedsOpened (HashMap.keys cs)
349         concat <$> mapM (uncurry getChangeParams) (HashMap.toList cs)
350       Nothing ->
351         error "WorkspaceEdit contains neither documentChanges nor changes!"
352
353   modifyM $ \s -> do
354     newVFS <- liftIO $ changeFromServerVFS (vfs s) r
355     return $ s { vfs = newVFS }
356
357   let groupedParams = groupBy (\a b -> a ^. textDocument == b ^. textDocument) allChangeParams
358       mergedParams = map mergeParams groupedParams
359
360   -- TODO: Don't do this when replaying a session
361   forM_ mergedParams (sendMessage . NotificationMessage "2.0" STextDocumentDidChange)
362
363   -- Update VFS to new document versions
364   let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
365       latestVersions = map ((^. textDocument) . last) sortedVersions
366       bumpedVersions = map (version . _Just +~ 1) latestVersions
367
368   forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
369     modify $ \s ->
370       let oldVFS = vfs s
371           update (VirtualFile oldV file_ver t) = VirtualFile (fromMaybe oldV v) (file_ver + 1) t
372           newVFS = updateVFS (Map.adjust update (toNormalizedUri uri)) oldVFS
373       in s { vfs = newVFS }
374
375   where checkIfNeedsOpened uri = do
376           oldVFS <- vfs <$> get
377           ctx <- ask
378
379           -- if its not open, open it
380           unless (toNormalizedUri uri `Map.member` vfsMap oldVFS) $ do
381             let fp = fromJust $ uriToFilePath uri
382             contents <- liftIO $ T.readFile fp
383             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
384                 msg = NotificationMessage "2.0" STextDocumentDidOpen (DidOpenTextDocumentParams item)
385             liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
386
387             modifyM $ \s -> do
388               let (newVFS,_) = openVFS (vfs s) msg
389               return $ s { vfs = newVFS }
390
391         getParamsFromTextDocumentEdit :: TextDocumentEdit -> DidChangeTextDocumentParams
392         getParamsFromTextDocumentEdit (TextDocumentEdit docId (List edits)) = 
393           let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
394             in DidChangeTextDocumentParams docId (List changeEvents)
395
396         getParamsFromDocumentChange :: DocumentChange -> Maybe DidChangeTextDocumentParams
397         getParamsFromDocumentChange (InL textDocumentEdit) = Just $ getParamsFromTextDocumentEdit textDocumentEdit
398         getParamsFromDocumentChange _ = Nothing
399
400
401         -- For a uri returns an infinite list of versions [n,n+1,n+2,...]
402         -- where n is the current version
403         textDocumentVersions uri = do
404           m <- vfsMap . vfs <$> get
405           let curVer = fromMaybe 0 $
406                 _lsp_version <$> m Map.!? (toNormalizedUri uri)
407           pure $ map (VersionedTextDocumentIdentifier uri . Just) [curVer + 1..]
408
409         textDocumentEdits uri edits = do
410           vers <- textDocumentVersions uri
411           pure $ map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip vers edits
412
413         getChangeParams uri (List edits) = do 
414           map <$> pure getParamsFromTextDocumentEdit <*> textDocumentEdits uri (reverse edits)
415
416         mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
417         mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
418                               in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
419 updateState _ = return ()
420
421 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
422 sendMessage msg = do
423   h <- serverIn <$> ask
424   logMsg LogClient msg
425   liftIO $ B.hPut h (addHeader $ encode msg)
426
427 -- | Execute a block f that will throw a 'Language.LSP.Test.Exception.Timeout' exception
428 -- after duration seconds. This will override the global timeout
429 -- for waiting for messages to arrive defined in 'SessionConfig'.
430 withTimeout :: Int -> Session a -> Session a
431 withTimeout duration f = do
432   chan <- asks messageChan
433   timeoutId <- getCurTimeoutId
434   modify $ \s -> s { overridingTimeout = True }
435   liftIO $ forkIO $ do
436     threadDelay (duration * 1000000)
437     writeChan chan (TimeoutMessage timeoutId)
438   res <- f
439   bumpTimeoutId timeoutId
440   modify $ \s -> s { overridingTimeout = False }
441   return res
442
443 data LogMsgType = LogServer | LogClient
444   deriving Eq
445
446 -- | Logs the message if the config specified it
447 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
448        => LogMsgType -> a -> m ()
449 logMsg t msg = do
450   shouldLog <- asks $ logMessages . config
451   shouldColor <- asks $ logColor . config
452   liftIO $ when shouldLog $ do
453     when shouldColor $ setSGR [SetColor Foreground Dull color]
454     putStrLn $ arrow ++ showPretty msg
455     when shouldColor $ setSGR [Reset]
456
457   where arrow
458           | t == LogServer  = "<-- "
459           | otherwise       = "--> "
460         color
461           | t == LogServer  = Magenta
462           | otherwise       = Cyan
463
464         showPretty = B.unpack . encodePretty