Kill timeout thread
[lsp-test.git] / src / Language / LSP / Test / Session.hs
1 {-# LANGUAGE CPP               #-}
2 {-# LANGUAGE BangPatterns      #-}
3 {-# LANGUAGE GADTs             #-}
4 {-# LANGUAGE OverloadedStrings #-}
5 {-# LANGUAGE GeneralizedNewtypeDeriving #-}
6 {-# LANGUAGE FlexibleInstances #-}
7 {-# LANGUAGE MultiParamTypeClasses #-}
8 {-# LANGUAGE FlexibleContexts #-}
9 {-# LANGUAGE RankNTypes #-}
10 {-# LANGUAGE TypeInType #-}
11
12 module Language.LSP.Test.Session
13   ( Session(..)
14   , SessionConfig(..)
15   , defaultConfig
16   , SessionMessage(..)
17   , SessionContext(..)
18   , SessionState(..)
19   , runSession'
20   , get
21   , put
22   , modify
23   , modifyM
24   , ask
25   , asks
26   , sendMessage
27   , updateState
28   , withTimeout
29   , getCurTimeoutId
30   , bumpTimeoutId
31   , logMsg
32   , LogMsgType(..)
33   , documentChangeUri
34   )
35
36 where
37
38 import Control.Applicative
39 import Control.Concurrent hiding (yield)
40 import Control.Exception
41 import Control.Lens hiding (List)
42 import Control.Monad
43 import Control.Monad.IO.Class
44 import Control.Monad.Except
45 #if __GLASGOW_HASKELL__ == 806
46 import Control.Monad.Fail
47 #endif
48 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
49 import qualified Control.Monad.Trans.Reader as Reader (ask)
50 import Control.Monad.Trans.State (StateT, runStateT)
51 import qualified Control.Monad.Trans.State as State
52 import qualified Data.ByteString.Lazy.Char8 as B
53 import Data.Aeson
54 import Data.Aeson.Encode.Pretty
55 import Data.Conduit as Conduit
56 import Data.Conduit.Parser as Parser
57 import Data.Default
58 import Data.Foldable
59 import Data.List
60 import qualified Data.Map.Strict as Map
61 import qualified Data.Set as Set
62 import qualified Data.Text as T
63 import qualified Data.Text.IO as T
64 import qualified Data.HashMap.Strict as HashMap
65 import Data.Maybe
66 import Data.Function
67 import Language.LSP.Types.Capabilities
68 import Language.LSP.Types
69 import Language.LSP.Types.Lens
70 import qualified Language.LSP.Types.Lens as LSP
71 import Language.LSP.VFS
72 import Language.LSP.Test.Compat
73 import Language.LSP.Test.Decoding
74 import Language.LSP.Test.Exceptions
75 import System.Console.ANSI
76 import System.Directory
77 import System.IO
78 import System.Process (ProcessHandle())
79 #ifndef mingw32_HOST_OS
80 import System.Process (waitForProcess)
81 #endif
82 import System.Timeout
83 import Data.IORef
84
85 -- | A session representing one instance of launching and connecting to a server.
86 --
87 -- You can send and receive messages to the server within 'Session' via
88 -- 'Language.LSP.Test.message',
89 -- 'Language.LSP.Test.sendRequest' and
90 -- 'Language.LSP.Test.sendNotification'.
91
92 newtype Session a = Session (ConduitParser FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) a)
93   deriving (Functor, Applicative, Monad, MonadIO, Alternative)
94
95 #if __GLASGOW_HASKELL__ >= 806
96 instance MonadFail Session where
97   fail s = do
98     lastMsg <- fromJust . lastReceivedMessage <$> get
99     liftIO $ throw (UnexpectedMessage s lastMsg)
100 #endif
101
102 -- | Stuff you can configure for a 'Session'.
103 data SessionConfig = SessionConfig
104   { messageTimeout :: Int  -- ^ Maximum time to wait for a message in seconds, defaults to 60.
105   , logStdErr      :: Bool
106   -- ^ Redirect the server's stderr to this stdout, defaults to False.
107   -- Can be overriden with @LSP_TEST_LOG_STDERR@.
108   , logMessages    :: Bool
109   -- ^ Trace the messages sent and received to stdout, defaults to False.
110   -- Can be overriden with the environment variable @LSP_TEST_LOG_MESSAGES@.
111   , logColor       :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
112   , lspConfig      :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
113   , ignoreLogNotifications :: Bool
114   -- ^ Whether or not to ignore 'Language.LSP.Types.ShowMessageNotification' and
115   -- 'Language.LSP.Types.LogMessageNotification', defaults to False.
116   --
117   -- @since 0.9.0.0
118   , initialWorkspaceFolders :: Maybe [WorkspaceFolder]
119   -- ^ The initial workspace folders to send in the @initialize@ request.
120   -- Defaults to Nothing.
121   }
122
123 -- | The configuration used in 'Language.LSP.Test.runSession'.
124 defaultConfig :: SessionConfig
125 defaultConfig = SessionConfig 60 False False True Nothing False Nothing
126
127 instance Default SessionConfig where
128   def = defaultConfig
129
130 data SessionMessage = ServerMessage FromServerMessage
131                     | TimeoutMessage Int
132   deriving Show
133
134 data SessionContext = SessionContext
135   {
136     serverIn :: Handle
137   , rootDir :: FilePath
138   , messageChan :: Chan SessionMessage -- ^ Where all messages come through
139   -- Keep curTimeoutId in SessionContext, as its tied to messageChan
140   , curTimeoutId :: IORef Int -- ^ The current timeout we are waiting on
141   , requestMap :: MVar RequestMap
142   , initRsp :: MVar (ResponseMessage Initialize)
143   , config :: SessionConfig
144   , sessionCapabilities :: ClientCapabilities
145   }
146
147 class Monad m => HasReader r m where
148   ask :: m r
149   asks :: (r -> b) -> m b
150   asks f = f <$> ask
151
152 instance HasReader SessionContext Session where
153   ask  = Session (lift $ lift Reader.ask)
154
155 instance Monad m => HasReader r (ConduitM a b (StateT s (ReaderT r m))) where
156   ask = lift $ lift Reader.ask
157
158 getCurTimeoutId :: (HasReader SessionContext m, MonadIO m) => m Int
159 getCurTimeoutId = asks curTimeoutId >>= liftIO . readIORef
160
161 -- Pass this the timeoutid you *were* waiting on
162 bumpTimeoutId :: (HasReader SessionContext m, MonadIO m) => Int -> m ()
163 bumpTimeoutId prev = do
164   v <- asks curTimeoutId
165   -- when updating the curtimeoutid, account for the fact that something else
166   -- might have bumped the timeoutid in the meantime
167   liftIO $ atomicModifyIORef' v (\x -> (max x (prev + 1), ()))
168
169 data SessionState = SessionState
170   {
171     curReqId :: !Int
172   , vfs :: !VFS
173   , curDiagnostics :: !(Map.Map NormalizedUri [Diagnostic])
174   , overridingTimeout :: !Bool
175   -- ^ The last received message from the server.
176   -- Used for providing exception information
177   , lastReceivedMessage :: !(Maybe FromServerMessage)
178   , curDynCaps :: !(Map.Map T.Text SomeRegistration)
179   -- ^ The capabilities that the server has dynamically registered with us so
180   -- far
181   , curProgressSessions :: !(Set.Set ProgressToken)
182   }
183
184 class Monad m => HasState s m where
185   get :: m s
186
187   put :: s -> m ()
188
189   modify :: (s -> s) -> m ()
190   modify f = get >>= put . f
191
192   modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
193   modifyM f = get >>= f >>= put
194
195 instance HasState SessionState Session where
196   get = Session (lift State.get)
197   put = Session . lift . State.put
198
199 instance Monad m => HasState s (StateT s m) where
200   get = State.get
201   put = State.put
202
203 instance (Monad m, (HasState s m)) => HasState s (ConduitM a b m)
204  where
205   get = lift get
206   put = lift . put
207
208 instance (Monad m, (HasState s m)) => HasState s (ConduitParser a m)
209  where
210   get = lift get
211   put = lift . put
212
213 runSessionMonad :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
214 runSessionMonad context state (Session session) = runReaderT (runStateT conduit state) context
215   where
216     conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
217
218     handler (Unexpected "ConduitParser.empty") = do
219       lastMsg <- fromJust . lastReceivedMessage <$> get
220       name <- getParserName
221       liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
222
223     handler e = throw e
224
225     chanSource = do
226       msg <- liftIO $ readChan (messageChan context)
227       unless (ignoreLogNotifications (config context) && isLogNotification msg) $
228         yield msg
229       chanSource
230
231     isLogNotification (ServerMessage (FromServerMess SWindowShowMessage _)) = True
232     isLogNotification (ServerMessage (FromServerMess SWindowLogMessage _)) = True
233     isLogNotification _ = False
234
235     watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
236     watchdog = Conduit.awaitForever $ \msg -> do
237       curId <- getCurTimeoutId
238       case msg of
239         ServerMessage sMsg -> yield sMsg
240         TimeoutMessage tId -> when (curId == tId) $ lastReceivedMessage <$> get >>= throw . Timeout
241
242 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
243 -- It also does not automatically send initialize and exit messages.
244 runSession' :: Handle -- ^ Server in
245             -> Handle -- ^ Server out
246             -> Maybe ProcessHandle -- ^ Server process
247             -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
248             -> SessionConfig
249             -> ClientCapabilities
250             -> FilePath -- ^ Root directory
251             -> Session () -- ^ To exit the Server properly
252             -> Session a
253             -> IO a
254 runSession' serverIn serverOut mServerProc serverHandler config caps rootDir exitServer session = do
255   absRootDir <- canonicalizePath rootDir
256
257   hSetBuffering serverIn  NoBuffering
258   hSetBuffering serverOut NoBuffering
259   -- This is required to make sure that we don’t get any
260   -- newline conversion or weird encoding issues.
261   hSetBinaryMode serverIn True
262   hSetBinaryMode serverOut True
263
264   reqMap <- newMVar newRequestMap
265   messageChan <- newChan
266   timeoutIdVar <- newIORef 0
267   initRsp <- newEmptyMVar
268
269   mainThreadId <- myThreadId
270
271   let context = SessionContext serverIn absRootDir messageChan timeoutIdVar reqMap initRsp config caps
272       initState vfs = SessionState 0 vfs mempty False Nothing mempty mempty
273       runSession' ses = initVFS $ \vfs -> runSessionMonad context (initState vfs) ses
274
275       errorHandler = throwTo mainThreadId :: SessionException -> IO ()
276       serverListenerLauncher =
277         forkIO $ catch (serverHandler serverOut context) errorHandler
278       msgTimeoutMs = messageTimeout config * 10^6
279       serverAndListenerFinalizer tid = do
280         let cleanup
281               | Just sp <- mServerProc = do
282                   -- Give the server some time to exit cleanly
283                   -- It makes the server hangs in windows so we have to avoid it
284 #ifndef mingw32_HOST_OS
285                   timeout msgTimeoutMs (waitForProcess sp)
286 #endif
287                   cleanupProcess (Just serverIn, Just serverOut, Nothing, sp)
288               | otherwise = pure ()
289         finally (timeout msgTimeoutMs (runSession' exitServer))
290                 -- Make sure to kill the listener first, before closing
291                 -- handles etc via cleanupProcess
292                 (killThread tid >> cleanup)
293
294   (result, _) <- bracket serverListenerLauncher
295                          serverAndListenerFinalizer
296                          (const $ initVFS $ \vfs -> runSessionMonad context (initState vfs) session)
297   return result
298
299 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
300 updateStateC = awaitForever $ \msg -> do
301   updateState msg
302   respond msg
303   yield msg
304   where
305     respond :: (MonadIO m, HasReader SessionContext m) => FromServerMessage -> m ()
306     respond (FromServerMess SWindowWorkDoneProgressCreate req) =
307       sendMessage $ ResponseMessage "2.0" (Just $ req ^. LSP.id) (Right ())
308     respond (FromServerMess SWorkspaceApplyEdit r) = do
309       sendMessage $ ResponseMessage "2.0" (Just $ r ^. LSP.id) (Right $ ApplyWorkspaceEditResponseBody True Nothing)
310     respond _ = pure ()
311
312
313 -- extract Uri out from DocumentChange
314 -- didn't put this in `lsp-types` because TH was getting in the way
315 documentChangeUri :: DocumentChange -> Uri
316 documentChangeUri (InL x) = x ^. textDocument . uri
317 documentChangeUri (InR (InL x)) = x ^. uri
318 documentChangeUri (InR (InR (InL x))) = x ^. oldUri
319 documentChangeUri (InR (InR (InR x))) = x ^. uri
320
321 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m)
322             => FromServerMessage -> m ()
323 updateState (FromServerMess SProgress req) = case req ^. params . value of
324   Begin _ ->
325     modify $ \s -> s { curProgressSessions = Set.insert (req ^. params . token) $ curProgressSessions s }
326   End _ ->
327     modify $ \s -> s { curProgressSessions = Set.delete (req ^. params . token) $ curProgressSessions s }
328   _ -> pure ()
329
330 -- Keep track of dynamic capability registration
331 updateState (FromServerMess SClientRegisterCapability req) = do
332   let List newRegs = (\sr@(SomeRegistration r) -> (r ^. LSP.id, sr)) <$> req ^. params . registrations
333   modify $ \s ->
334     s { curDynCaps = Map.union (Map.fromList newRegs) (curDynCaps s) }
335
336 updateState (FromServerMess SClientUnregisterCapability req) = do
337   let List unRegs = (^. LSP.id) <$> req ^. params . unregisterations
338   modify $ \s ->
339     let newCurDynCaps = foldr' Map.delete (curDynCaps s) unRegs
340     in s { curDynCaps = newCurDynCaps }
341
342 updateState (FromServerMess STextDocumentPublishDiagnostics n) = do
343   let List diags = n ^. params . diagnostics
344       doc = n ^. params . uri
345   modify $ \s ->
346     let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s)
347       in s { curDiagnostics = newDiags }
348
349 updateState (FromServerMess SWorkspaceApplyEdit r) = do
350
351   -- First, prefer the versioned documentChanges field
352   allChangeParams <- case r ^. params . edit . documentChanges of
353     Just (List cs) -> do
354       mapM_ (checkIfNeedsOpened . documentChangeUri) cs
355       return $ mapMaybe getParamsFromDocumentChange cs
356     -- Then fall back to the changes field
357     Nothing -> case r ^. params . edit . changes of
358       Just cs -> do
359         mapM_ checkIfNeedsOpened (HashMap.keys cs)
360         concat <$> mapM (uncurry getChangeParams) (HashMap.toList cs)
361       Nothing ->
362         error "WorkspaceEdit contains neither documentChanges nor changes!"
363
364   modifyM $ \s -> do
365     newVFS <- liftIO $ changeFromServerVFS (vfs s) r
366     return $ s { vfs = newVFS }
367
368   let groupedParams = groupBy (\a b -> a ^. textDocument == b ^. textDocument) allChangeParams
369       mergedParams = map mergeParams groupedParams
370
371   -- TODO: Don't do this when replaying a session
372   forM_ mergedParams (sendMessage . NotificationMessage "2.0" STextDocumentDidChange)
373
374   -- Update VFS to new document versions
375   let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
376       latestVersions = map ((^. textDocument) . last) sortedVersions
377       bumpedVersions = map (version . _Just +~ 1) latestVersions
378
379   forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
380     modify $ \s ->
381       let oldVFS = vfs s
382           update (VirtualFile oldV file_ver t) = VirtualFile (fromMaybe oldV v) (file_ver + 1) t
383           newVFS = updateVFS (Map.adjust update (toNormalizedUri uri)) oldVFS
384       in s { vfs = newVFS }
385
386   where checkIfNeedsOpened uri = do
387           oldVFS <- vfs <$> get
388           ctx <- ask
389
390           -- if its not open, open it
391           unless (toNormalizedUri uri `Map.member` vfsMap oldVFS) $ do
392             let fp = fromJust $ uriToFilePath uri
393             contents <- liftIO $ T.readFile fp
394             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
395                 msg = NotificationMessage "2.0" STextDocumentDidOpen (DidOpenTextDocumentParams item)
396             liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
397
398             modifyM $ \s -> do
399               let (newVFS,_) = openVFS (vfs s) msg
400               return $ s { vfs = newVFS }
401
402         getParamsFromTextDocumentEdit :: TextDocumentEdit -> DidChangeTextDocumentParams
403         getParamsFromTextDocumentEdit (TextDocumentEdit docId (List edits)) = 
404           let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
405             in DidChangeTextDocumentParams docId (List changeEvents)
406
407         getParamsFromDocumentChange :: DocumentChange -> Maybe DidChangeTextDocumentParams
408         getParamsFromDocumentChange (InL textDocumentEdit) = Just $ getParamsFromTextDocumentEdit textDocumentEdit
409         getParamsFromDocumentChange _ = Nothing
410
411
412         -- For a uri returns an infinite list of versions [n,n+1,n+2,...]
413         -- where n is the current version
414         textDocumentVersions uri = do
415           m <- vfsMap . vfs <$> get
416           let curVer = fromMaybe 0 $
417                 _lsp_version <$> m Map.!? (toNormalizedUri uri)
418           pure $ map (VersionedTextDocumentIdentifier uri . Just) [curVer + 1..]
419
420         textDocumentEdits uri edits = do
421           vers <- textDocumentVersions uri
422           pure $ map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip vers edits
423
424         getChangeParams uri (List edits) = do 
425           map <$> pure getParamsFromTextDocumentEdit <*> textDocumentEdits uri (reverse edits)
426
427         mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
428         mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
429                               in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
430 updateState _ = return ()
431
432 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
433 sendMessage msg = do
434   h <- serverIn <$> ask
435   logMsg LogClient msg
436   liftIO $ B.hPut h (addHeader $ encode msg)
437
438 -- | Execute a block f that will throw a 'Language.LSP.Test.Exception.Timeout' exception
439 -- after duration seconds. This will override the global timeout
440 -- for waiting for messages to arrive defined in 'SessionConfig'.
441 withTimeout :: Int -> Session a -> Session a
442 withTimeout duration f = do
443   chan <- asks messageChan
444   timeoutId <- getCurTimeoutId
445   modify $ \s -> s { overridingTimeout = True }
446   tid <- liftIO $ forkIO $ do
447     threadDelay (duration * 1000000)
448     writeChan chan (TimeoutMessage timeoutId)
449   res <- f
450   liftIO $ killThread tid
451   bumpTimeoutId timeoutId
452   modify $ \s -> s { overridingTimeout = False }
453   return res
454
455 data LogMsgType = LogServer | LogClient
456   deriving Eq
457
458 -- | Logs the message if the config specified it
459 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
460        => LogMsgType -> a -> m ()
461 logMsg t msg = do
462   shouldLog <- asks $ logMessages . config
463   shouldColor <- asks $ logColor . config
464   liftIO $ when shouldLog $ do
465     when shouldColor $ setSGR [SetColor Foreground Dull color]
466     putStrLn $ arrow ++ showPretty msg
467     when shouldColor $ setSGR [Reset]
468
469   where arrow
470           | t == LogServer  = "<-- "
471           | otherwise       = "--> "
472         color
473           | t == LogServer  = Magenta
474           | otherwise       = Cyan
475
476         showPretty = B.unpack . encodePretty