track progress sessions
[lsp-test.git] / src / Language / Haskell / LSP / Test / Session.hs
1 {-# LANGUAGE CPP               #-}
2 {-# LANGUAGE OverloadedStrings #-}
3 {-# LANGUAGE GeneralizedNewtypeDeriving #-}
4 {-# LANGUAGE FlexibleInstances #-}
5 {-# LANGUAGE MultiParamTypeClasses #-}
6 {-# LANGUAGE FlexibleContexts #-}
7 {-# LANGUAGE RankNTypes #-}
8
9 module Language.Haskell.LSP.Test.Session
10   ( Session(..)
11   , SessionConfig(..)
12   , defaultConfig
13   , SessionMessage(..)
14   , SessionContext(..)
15   , SessionState(..)
16   , runSessionWithHandles
17   , get
18   , put
19   , modify
20   , modifyM
21   , ask
22   , asks
23   , sendMessage
24   , updateState
25   , withTimeout
26   , getCurTimeoutId
27   , bumpTimeoutId
28   , logMsg
29   , LogMsgType(..)
30   )
31
32 where
33
34 import Control.Applicative
35 import Control.Concurrent hiding (yield)
36 import Control.Exception
37 import Control.Lens hiding (List)
38 import Control.Monad
39 import Control.Monad.IO.Class
40 import Control.Monad.Except
41 #if __GLASGOW_HASKELL__ == 806
42 import Control.Monad.Fail
43 #endif
44 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
45 import qualified Control.Monad.Trans.Reader as Reader (ask)
46 import Control.Monad.Trans.State (StateT, runStateT)
47 import qualified Control.Monad.Trans.State as State
48 import qualified Data.ByteString.Lazy.Char8 as B
49 import Data.Aeson
50 import Data.Aeson.Encode.Pretty
51 import Data.Conduit as Conduit
52 import Data.Conduit.Parser as Parser
53 import Data.Default
54 import Data.Foldable
55 import Data.List
56 import qualified Data.Map as Map
57 import qualified Data.Set as Set
58 import qualified Data.Text as T
59 import qualified Data.Text.IO as T
60 import qualified Data.HashMap.Strict as HashMap
61 import Data.Maybe
62 import Data.Function
63 import Language.Haskell.LSP.Messages
64 import Language.Haskell.LSP.Types.Capabilities
65 import Language.Haskell.LSP.Types
66 import Language.Haskell.LSP.Types.Lens
67 import qualified Language.Haskell.LSP.Types.Lens as LSP
68 import Language.Haskell.LSP.VFS
69 import Language.Haskell.LSP.Test.Compat
70 import Language.Haskell.LSP.Test.Decoding
71 import Language.Haskell.LSP.Test.Exceptions
72 import System.Console.ANSI
73 import System.Directory
74 import System.IO
75 import System.Process (ProcessHandle())
76 #ifndef mingw32_HOST_OS
77 import System.Process (waitForProcess)
78 #endif
79 import System.Timeout
80
81 -- | A session representing one instance of launching and connecting to a server.
82 --
83 -- You can send and receive messages to the server within 'Session' via
84 -- 'Language.Haskell.LSP.Test.message',
85 -- 'Language.Haskell.LSP.Test.sendRequest' and
86 -- 'Language.Haskell.LSP.Test.sendNotification'.
87
88 newtype Session a = Session (ConduitParser FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) a)
89   deriving (Functor, Applicative, Monad, MonadIO, Alternative)
90
91 #if __GLASGOW_HASKELL__ >= 806
92 instance MonadFail Session where
93   fail s = do
94     lastMsg <- fromJust . lastReceivedMessage <$> get
95     liftIO $ throw (UnexpectedMessage s lastMsg)
96 #endif
97
98 -- | Stuff you can configure for a 'Session'.
99 data SessionConfig = SessionConfig
100   { messageTimeout :: Int  -- ^ Maximum time to wait for a message in seconds, defaults to 60.
101   , logStdErr      :: Bool
102   -- ^ Redirect the server's stderr to this stdout, defaults to False.
103   -- Can be overriden with @LSP_TEST_LOG_STDERR@.
104   , logMessages    :: Bool
105   -- ^ Trace the messages sent and received to stdout, defaults to False.
106   -- Can be overriden with the environment variable @LSP_TEST_LOG_MESSAGES@.
107   , logColor       :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
108   , lspConfig      :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
109   , ignoreLogNotifications :: Bool
110   -- ^ Whether or not to ignore 'Language.Haskell.LSP.Types.ShowMessageNotification' and
111   -- 'Language.Haskell.LSP.Types.LogMessageNotification', defaults to False.
112   --
113   -- @since 0.9.0.0
114   }
115
116 -- | The configuration used in 'Language.Haskell.LSP.Test.runSession'.
117 defaultConfig :: SessionConfig
118 defaultConfig = SessionConfig 60 False False True Nothing False
119
120 instance Default SessionConfig where
121   def = defaultConfig
122
123 data SessionMessage = ServerMessage FromServerMessage
124                     | TimeoutMessage Int
125   deriving Show
126
127 data SessionContext = SessionContext
128   {
129     serverIn :: Handle
130   , rootDir :: FilePath
131   , messageChan :: Chan SessionMessage -- ^ Where all messages come through
132   -- Keep curTimeoutId in SessionContext, as its tied to messageChan
133   , curTimeoutId :: MVar Int -- ^ The current timeout we are waiting on
134   , requestMap :: MVar RequestMap
135   , initRsp :: MVar InitializeResponse
136   , config :: SessionConfig
137   , sessionCapabilities :: ClientCapabilities
138   }
139
140 class Monad m => HasReader r m where
141   ask :: m r
142   asks :: (r -> b) -> m b
143   asks f = f <$> ask
144
145 instance HasReader SessionContext Session where
146   ask  = Session (lift $ lift Reader.ask)
147
148 instance Monad m => HasReader r (ConduitM a b (StateT s (ReaderT r m))) where
149   ask = lift $ lift Reader.ask
150
151 getCurTimeoutId :: (HasReader SessionContext m, MonadIO m) => m Int
152 getCurTimeoutId = asks curTimeoutId >>= liftIO . readMVar
153
154 -- Pass this the timeoutid you *were* waiting on
155 bumpTimeoutId :: (HasReader SessionContext m, MonadIO m) => Int -> m ()
156 bumpTimeoutId prev = do
157   v <- asks curTimeoutId
158   -- when updating the curtimeoutid, account for the fact that something else
159   -- might have bumped the timeoutid in the meantime
160   liftIO $ modifyMVar_ v (\x -> pure (max x (prev + 1)))
161
162 data SessionState = SessionState
163   {
164     curReqId :: LspId
165   , vfs :: VFS
166   , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
167   , overridingTimeout :: Bool
168   -- ^ The last received message from the server.
169   -- Used for providing exception information
170   , lastReceivedMessage :: Maybe FromServerMessage
171   , curDynCaps :: Map.Map T.Text Registration
172   -- ^ The capabilities that the server has dynamically registered with us so
173   -- far
174   , curProgressSessions :: Set.Set ProgressToken
175   }
176
177 class Monad m => HasState s m where
178   get :: m s
179
180   put :: s -> m ()
181
182   modify :: (s -> s) -> m ()
183   modify f = get >>= put . f
184
185   modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
186   modifyM f = get >>= f >>= put
187
188 instance HasState SessionState Session where
189   get = Session (lift State.get)
190   put = Session . lift . State.put
191
192 instance Monad m => HasState s (StateT s m) where
193   get = State.get
194   put = State.put
195
196 instance (Monad m, (HasState s m)) => HasState s (ConduitM a b m)
197  where
198   get = lift get
199   put = lift . put
200
201 instance (Monad m, (HasState s m)) => HasState s (ConduitParser a m)
202  where
203   get = lift get
204   put = lift . put
205
206 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
207 runSession context state (Session session) = runReaderT (runStateT conduit state) context
208   where
209     conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
210
211     handler (Unexpected "ConduitParser.empty") = do
212       lastMsg <- fromJust . lastReceivedMessage <$> get
213       name <- getParserName
214       liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
215
216     handler e = throw e
217
218     chanSource = do
219       msg <- liftIO $ readChan (messageChan context)
220       unless (ignoreLogNotifications (config context) && isLogNotification msg) $
221         yield msg
222       chanSource
223
224     isLogNotification (ServerMessage (NotShowMessage _)) = True
225     isLogNotification (ServerMessage (NotLogMessage _)) = True
226     isLogNotification _ = False
227
228     watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
229     watchdog = Conduit.awaitForever $ \msg -> do
230       curId <- getCurTimeoutId
231       case msg of
232         ServerMessage sMsg -> yield sMsg
233         TimeoutMessage tId -> when (curId == tId) $ lastReceivedMessage <$> get >>= throw . Timeout
234
235 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
236 -- It also does not automatically send initialize and exit messages.
237 runSessionWithHandles :: Handle -- ^ Server in
238                       -> Handle -- ^ Server out
239                       -> ProcessHandle -- ^ Server process
240                       -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
241                       -> SessionConfig
242                       -> ClientCapabilities
243                       -> FilePath -- ^ Root directory
244                       -> Session () -- ^ To exit the Server properly
245                       -> Session a
246                       -> IO a
247 runSessionWithHandles serverIn serverOut serverProc serverHandler config caps rootDir exitServer session = do
248   absRootDir <- canonicalizePath rootDir
249
250   hSetBuffering serverIn  NoBuffering
251   hSetBuffering serverOut NoBuffering
252   -- This is required to make sure that we don’t get any
253   -- newline conversion or weird encoding issues.
254   hSetBinaryMode serverIn True
255   hSetBinaryMode serverOut True
256
257   reqMap <- newMVar newRequestMap
258   messageChan <- newChan
259   timeoutIdVar <- newMVar 0
260   initRsp <- newEmptyMVar
261
262   mainThreadId <- myThreadId
263
264   let context = SessionContext serverIn absRootDir messageChan timeoutIdVar reqMap initRsp config caps
265       initState vfs = SessionState (IdInt 0) vfs mempty False Nothing mempty mempty
266       runSession' ses = initVFS $ \vfs -> runSession context (initState vfs) ses
267
268       errorHandler = throwTo mainThreadId :: SessionException -> IO ()
269       serverListenerLauncher =
270         forkIO $ catch (serverHandler serverOut context) errorHandler
271       server = (Just serverIn, Just serverOut, Nothing, serverProc)
272       msgTimeoutMs = messageTimeout config * 10^6
273       serverAndListenerFinalizer tid = do
274         finally (timeout msgTimeoutMs (runSession' exitServer)) $ do
275           -- Make sure to kill the listener first, before closing
276           -- handles etc via cleanupProcess
277           killThread tid
278           -- Give the server some time to exit cleanly
279           -- It makes the server hangs in windows so we have to avoid it
280 #ifndef mingw32_HOST_OS
281           timeout msgTimeoutMs (waitForProcess serverProc)
282 #endif
283           cleanupProcess server
284
285   (result, _) <- bracket serverListenerLauncher
286                          serverAndListenerFinalizer
287                          (const $ runSession' session)
288   return result
289
290 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
291 updateStateC = awaitForever $ \msg -> do
292   updateState msg
293   yield msg
294
295 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m)
296             => FromServerMessage -> m ()
297 updateState (NotWorkDoneProgressBegin req) =
298   modify $ \s -> s { curProgressSessions = Set.insert (req ^. params . token) $ curProgressSessions s }
299 updateState (NotWorkDoneProgressEnd req) =
300   modify $ \s -> s { curProgressSessions = Set.delete (req ^. params . token) $ curProgressSessions s }
301
302 -- Keep track of dynamic capability registration
303 updateState (ReqRegisterCapability req) = do
304   let List newRegs = (\r -> (r ^. LSP.id, r)) <$> req ^. params . registrations
305   modify $ \s ->
306     s { curDynCaps = Map.union (Map.fromList newRegs) (curDynCaps s) }
307
308 updateState (ReqUnregisterCapability req) = do
309   let List unRegs = (^. LSP.id) <$> req ^. params . unregistrations
310   modify $ \s ->
311     let newCurDynCaps = foldr' Map.delete (curDynCaps s) unRegs
312     in s { curDynCaps = newCurDynCaps }
313
314 updateState (NotPublishDiagnostics n) = do
315   let List diags = n ^. params . diagnostics
316       doc = n ^. params . uri
317   modify $ \s ->
318     let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s)
319       in s { curDiagnostics = newDiags }
320
321 updateState (ReqApplyWorkspaceEdit r) = do
322
323   -- First, prefer the versioned documentChanges field
324   allChangeParams <- case r ^. params . edit . documentChanges of
325     Just (List cs) -> do
326       mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
327       return $ map getParams cs
328     -- Then fall back to the changes field
329     Nothing -> case r ^. params . edit . changes of
330       Just cs -> do
331         mapM_ checkIfNeedsOpened (HashMap.keys cs)
332         concat <$> mapM (uncurry getChangeParams) (HashMap.toList cs)
333       Nothing ->
334         error "WorkspaceEdit contains neither documentChanges nor changes!"
335
336   modifyM $ \s -> do
337     newVFS <- liftIO $ changeFromServerVFS (vfs s) r
338     return $ s { vfs = newVFS }
339
340   let groupedParams = groupBy (\a b -> a ^. textDocument == b ^. textDocument) allChangeParams
341       mergedParams = map mergeParams groupedParams
342
343   -- TODO: Don't do this when replaying a session
344   forM_ mergedParams (sendMessage . NotificationMessage "2.0" TextDocumentDidChange)
345
346   -- Update VFS to new document versions
347   let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
348       latestVersions = map ((^. textDocument) . last) sortedVersions
349       bumpedVersions = map (version . _Just +~ 1) latestVersions
350
351   forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
352     modify $ \s ->
353       let oldVFS = vfs s
354           update (VirtualFile oldV file_ver t) = VirtualFile (fromMaybe oldV v) (file_ver + 1) t
355           newVFS = updateVFS (Map.adjust update (toNormalizedUri uri)) oldVFS
356       in s { vfs = newVFS }
357
358   where checkIfNeedsOpened uri = do
359           oldVFS <- vfs <$> get
360           ctx <- ask
361
362           -- if its not open, open it
363           unless (toNormalizedUri uri `Map.member` vfsMap oldVFS) $ do
364             let fp = fromJust $ uriToFilePath uri
365             contents <- liftIO $ T.readFile fp
366             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
367                 msg = NotificationMessage "2.0" TextDocumentDidOpen (DidOpenTextDocumentParams item)
368             liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
369
370             modifyM $ \s -> do
371               let (newVFS,_) = openVFS (vfs s) msg
372               return $ s { vfs = newVFS }
373
374         getParams (TextDocumentEdit docId (List edits)) =
375           let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
376             in DidChangeTextDocumentParams docId (List changeEvents)
377
378         -- For a uri returns an infinite list of versions [n,n+1,n+2,...]
379         -- where n is the current version
380         textDocumentVersions uri = do
381           m <- vfsMap . vfs <$> get
382           let curVer = fromMaybe 0 $
383                 _lsp_version <$> m Map.!? (toNormalizedUri uri)
384           pure $ map (VersionedTextDocumentIdentifier uri . Just) [curVer + 1..]
385
386         textDocumentEdits uri edits = do
387           vers <- textDocumentVersions uri
388           pure $ map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip vers edits
389
390         getChangeParams uri (List edits) =
391           map <$> pure getParams <*> textDocumentEdits uri (reverse edits)
392
393         mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
394         mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
395                               in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
396 updateState _ = return ()
397
398 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
399 sendMessage msg = do
400   h <- serverIn <$> ask
401   logMsg LogClient msg
402   liftIO $ B.hPut h (addHeader $ encode msg)
403
404 -- | Execute a block f that will throw a 'Language.Haskell.LSP.Test.Exception.Timeout' exception
405 -- after duration seconds. This will override the global timeout
406 -- for waiting for messages to arrive defined in 'SessionConfig'.
407 withTimeout :: Int -> Session a -> Session a
408 withTimeout duration f = do
409   chan <- asks messageChan
410   timeoutId <- getCurTimeoutId
411   modify $ \s -> s { overridingTimeout = True }
412   liftIO $ forkIO $ do
413     threadDelay (duration * 1000000)
414     writeChan chan (TimeoutMessage timeoutId)
415   res <- f
416   bumpTimeoutId timeoutId
417   modify $ \s -> s { overridingTimeout = False }
418   return res
419
420 data LogMsgType = LogServer | LogClient
421   deriving Eq
422
423 -- | Logs the message if the config specified it
424 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
425        => LogMsgType -> a -> m ()
426 logMsg t msg = do
427   shouldLog <- asks $ logMessages . config
428   shouldColor <- asks $ logColor . config
429   liftIO $ when shouldLog $ do
430     when shouldColor $ setSGR [SetColor Foreground Dull color]
431     putStrLn $ arrow ++ showPretty msg
432     when shouldColor $ setSGR [Reset]
433
434   where arrow
435           | t == LogServer  = "<-- "
436           | otherwise       = "--> "
437         color
438           | t == LogServer  = Magenta
439           | otherwise       = Cyan
440
441         showPretty = B.unpack . encodePretty