55055cdebfce545b65755a9c5b7f3c72f33feb3b
[lsp-test.git] / src / Language / LSP / Test / Session.hs
1 {-# LANGUAGE CPP               #-}
2 {-# LANGUAGE GADTs             #-}
3 {-# LANGUAGE OverloadedStrings #-}
4 {-# LANGUAGE GeneralizedNewtypeDeriving #-}
5 {-# LANGUAGE FlexibleInstances #-}
6 {-# LANGUAGE MultiParamTypeClasses #-}
7 {-# LANGUAGE FlexibleContexts #-}
8 {-# LANGUAGE RankNTypes #-}
9 {-# LANGUAGE TypeInType #-}
10
11 module Language.LSP.Test.Session
12   ( Session(..)
13   , SessionConfig(..)
14   , defaultConfig
15   , SessionMessage(..)
16   , SessionContext(..)
17   , SessionState(..)
18   , runSession'
19   , get
20   , put
21   , modify
22   , modifyM
23   , ask
24   , asks
25   , sendMessage
26   , updateState
27   , withTimeout
28   , getCurTimeoutId
29   , bumpTimeoutId
30   , logMsg
31   , LogMsgType(..)
32   , documentChangeUri
33   )
34
35 where
36
37 import Control.Applicative
38 import Control.Concurrent hiding (yield)
39 import Control.Exception
40 import Control.Lens hiding (List)
41 import Control.Monad
42 import Control.Monad.IO.Class
43 import Control.Monad.Except
44 #if __GLASGOW_HASKELL__ == 806
45 import Control.Monad.Fail
46 #endif
47 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
48 import qualified Control.Monad.Trans.Reader as Reader (ask)
49 import Control.Monad.Trans.State (StateT, runStateT)
50 import qualified Control.Monad.Trans.State as State
51 import qualified Data.ByteString.Lazy.Char8 as B
52 import Data.Aeson
53 import Data.Aeson.Encode.Pretty
54 import Data.Conduit as Conduit
55 import Data.Conduit.Parser as Parser
56 import Data.Default
57 import Data.Foldable
58 import Data.List
59 import qualified Data.Map as Map
60 import qualified Data.Set as Set
61 import qualified Data.Text as T
62 import qualified Data.Text.IO as T
63 import qualified Data.HashMap.Strict as HashMap
64 import Data.Maybe
65 import Data.Function
66 import Language.LSP.Types.Capabilities
67 import Language.LSP.Types
68 import Language.LSP.Types.Lens
69 import qualified Language.LSP.Types.Lens as LSP
70 import Language.LSP.VFS
71 import Language.LSP.Test.Compat
72 import Language.LSP.Test.Decoding
73 import Language.LSP.Test.Exceptions
74 import System.Console.ANSI
75 import System.Directory
76 import System.IO
77 import System.Process (ProcessHandle())
78 #ifndef mingw32_HOST_OS
79 import System.Process (waitForProcess)
80 #endif
81 import System.Timeout
82
83 -- | A session representing one instance of launching and connecting to a server.
84 --
85 -- You can send and receive messages to the server within 'Session' via
86 -- 'Language.LSP.Test.message',
87 -- 'Language.LSP.Test.sendRequest' and
88 -- 'Language.LSP.Test.sendNotification'.
89
90 newtype Session a = Session (ConduitParser FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) a)
91   deriving (Functor, Applicative, Monad, MonadIO, Alternative)
92
93 #if __GLASGOW_HASKELL__ >= 806
94 instance MonadFail Session where
95   fail s = do
96     lastMsg <- fromJust . lastReceivedMessage <$> get
97     liftIO $ throw (UnexpectedMessage s lastMsg)
98 #endif
99
100 -- | Stuff you can configure for a 'Session'.
101 data SessionConfig = SessionConfig
102   { messageTimeout :: Int  -- ^ Maximum time to wait for a message in seconds, defaults to 60.
103   , logStdErr      :: Bool
104   -- ^ Redirect the server's stderr to this stdout, defaults to False.
105   -- Can be overriden with @LSP_TEST_LOG_STDERR@.
106   , logMessages    :: Bool
107   -- ^ Trace the messages sent and received to stdout, defaults to False.
108   -- Can be overriden with the environment variable @LSP_TEST_LOG_MESSAGES@.
109   , logColor       :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
110   , lspConfig      :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
111   , ignoreLogNotifications :: Bool
112   -- ^ Whether or not to ignore 'Language.LSP.Types.ShowMessageNotification' and
113   -- 'Language.LSP.Types.LogMessageNotification', defaults to False.
114   --
115   -- @since 0.9.0.0
116   , initialWorkspaceFolders :: Maybe [WorkspaceFolder]
117   -- ^ The initial workspace folders to send in the @initialize@ request.
118   -- Defaults to Nothing.
119   }
120
121 -- | The configuration used in 'Language.LSP.Test.runSession'.
122 defaultConfig :: SessionConfig
123 defaultConfig = SessionConfig 60 False False True Nothing False Nothing
124
125 instance Default SessionConfig where
126   def = defaultConfig
127
128 data SessionMessage = ServerMessage FromServerMessage
129                     | TimeoutMessage Int
130   deriving Show
131
132 data SessionContext = SessionContext
133   {
134     serverIn :: Handle
135   , rootDir :: FilePath
136   , messageChan :: Chan SessionMessage -- ^ Where all messages come through
137   -- Keep curTimeoutId in SessionContext, as its tied to messageChan
138   , curTimeoutId :: MVar Int -- ^ The current timeout we are waiting on
139   , requestMap :: MVar RequestMap
140   , initRsp :: MVar (ResponseMessage Initialize)
141   , config :: SessionConfig
142   , sessionCapabilities :: ClientCapabilities
143   }
144
145 class Monad m => HasReader r m where
146   ask :: m r
147   asks :: (r -> b) -> m b
148   asks f = f <$> ask
149
150 instance HasReader SessionContext Session where
151   ask  = Session (lift $ lift Reader.ask)
152
153 instance Monad m => HasReader r (ConduitM a b (StateT s (ReaderT r m))) where
154   ask = lift $ lift Reader.ask
155
156 getCurTimeoutId :: (HasReader SessionContext m, MonadIO m) => m Int
157 getCurTimeoutId = asks curTimeoutId >>= liftIO . readMVar
158
159 -- Pass this the timeoutid you *were* waiting on
160 bumpTimeoutId :: (HasReader SessionContext m, MonadIO m) => Int -> m ()
161 bumpTimeoutId prev = do
162   v <- asks curTimeoutId
163   -- when updating the curtimeoutid, account for the fact that something else
164   -- might have bumped the timeoutid in the meantime
165   liftIO $ modifyMVar_ v (\x -> pure (max x (prev + 1)))
166
167 data SessionState = SessionState
168   {
169     curReqId :: Int
170   , vfs :: VFS
171   , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
172   , overridingTimeout :: Bool
173   -- ^ The last received message from the server.
174   -- Used for providing exception information
175   , lastReceivedMessage :: Maybe FromServerMessage
176   , curDynCaps :: Map.Map T.Text SomeRegistration
177   -- ^ The capabilities that the server has dynamically registered with us so
178   -- far
179   , curProgressSessions :: Set.Set ProgressToken
180   }
181
182 class Monad m => HasState s m where
183   get :: m s
184
185   put :: s -> m ()
186
187   modify :: (s -> s) -> m ()
188   modify f = get >>= put . f
189
190   modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
191   modifyM f = get >>= f >>= put
192
193 instance HasState SessionState Session where
194   get = Session (lift State.get)
195   put = Session . lift . State.put
196
197 instance Monad m => HasState s (StateT s m) where
198   get = State.get
199   put = State.put
200
201 instance (Monad m, (HasState s m)) => HasState s (ConduitM a b m)
202  where
203   get = lift get
204   put = lift . put
205
206 instance (Monad m, (HasState s m)) => HasState s (ConduitParser a m)
207  where
208   get = lift get
209   put = lift . put
210
211 runSessionMonad :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
212 runSessionMonad context state (Session session) = runReaderT (runStateT conduit state) context
213   where
214     conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
215
216     handler (Unexpected "ConduitParser.empty") = do
217       lastMsg <- fromJust . lastReceivedMessage <$> get
218       name <- getParserName
219       liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
220
221     handler e = throw e
222
223     chanSource = do
224       msg <- liftIO $ readChan (messageChan context)
225       unless (ignoreLogNotifications (config context) && isLogNotification msg) $
226         yield msg
227       chanSource
228
229     isLogNotification (ServerMessage (FromServerMess SWindowShowMessage _)) = True
230     isLogNotification (ServerMessage (FromServerMess SWindowLogMessage _)) = True
231     isLogNotification _ = False
232
233     watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
234     watchdog = Conduit.awaitForever $ \msg -> do
235       curId <- getCurTimeoutId
236       case msg of
237         ServerMessage sMsg -> yield sMsg
238         TimeoutMessage tId -> when (curId == tId) $ lastReceivedMessage <$> get >>= throw . Timeout
239
240 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
241 -- It also does not automatically send initialize and exit messages.
242 runSession' :: Handle -- ^ Server in
243             -> Handle -- ^ Server out
244             -> Maybe ProcessHandle -- ^ Server process
245             -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
246             -> SessionConfig
247             -> ClientCapabilities
248             -> FilePath -- ^ Root directory
249             -> Session () -- ^ To exit the Server properly
250             -> Session a
251             -> IO a
252 runSession' serverIn serverOut mServerProc serverHandler config caps rootDir exitServer session = do
253   absRootDir <- canonicalizePath rootDir
254
255   hSetBuffering serverIn  NoBuffering
256   hSetBuffering serverOut NoBuffering
257   -- This is required to make sure that we don’t get any
258   -- newline conversion or weird encoding issues.
259   hSetBinaryMode serverIn True
260   hSetBinaryMode serverOut True
261
262   reqMap <- newMVar newRequestMap
263   messageChan <- newChan
264   timeoutIdVar <- newMVar 0
265   initRsp <- newEmptyMVar
266
267   mainThreadId <- myThreadId
268
269   let context = SessionContext serverIn absRootDir messageChan timeoutIdVar reqMap initRsp config caps
270       initState vfs = SessionState 0 vfs mempty False Nothing mempty mempty
271       runSession' ses = initVFS $ \vfs -> runSessionMonad context (initState vfs) ses
272
273       errorHandler = throwTo mainThreadId :: SessionException -> IO ()
274       serverListenerLauncher =
275         forkIO $ catch (serverHandler serverOut context) errorHandler
276       msgTimeoutMs = messageTimeout config * 10^6
277       serverAndListenerFinalizer tid = do
278         let cleanup
279               | Just sp <- mServerProc = do
280                   -- Give the server some time to exit cleanly
281                   -- It makes the server hangs in windows so we have to avoid it
282 #ifndef mingw32_HOST_OS
283                   timeout msgTimeoutMs (waitForProcess sp)
284 #endif
285                   cleanupProcess (Just serverIn, Just serverOut, Nothing, sp)
286               | otherwise = pure ()
287         finally (timeout msgTimeoutMs (runSession' exitServer))
288                 -- Make sure to kill the listener first, before closing
289                 -- handles etc via cleanupProcess
290                 (killThread tid >> cleanup)
291
292   (result, _) <- bracket serverListenerLauncher
293                          serverAndListenerFinalizer
294                          (const $ initVFS $ \vfs -> runSessionMonad context (initState vfs) session)
295   return result
296
297 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
298 updateStateC = awaitForever $ \msg -> do
299   updateState msg
300   respond msg
301   yield msg
302   where
303     respond :: (MonadIO m, HasReader SessionContext m) => FromServerMessage -> m ()
304     respond (FromServerMess SWindowWorkDoneProgressCreate req) =
305       sendMessage $ ResponseMessage "2.0" (Just $ req ^. LSP.id) (Right ())
306     respond (FromServerMess SWorkspaceApplyEdit r) = do
307       sendMessage $ ResponseMessage "2.0" (Just $ r ^. LSP.id) (Right $ ApplyWorkspaceEditResponseBody True Nothing)
308     respond _ = pure ()
309
310
311 -- extract Uri out from DocumentChange
312 -- didn't put this in `lsp-types` because TH was getting in the way
313 documentChangeUri :: DocumentChange -> Uri
314 documentChangeUri (InL x) = x ^. textDocument . uri
315 documentChangeUri (InR (InL x)) = x ^. uri
316 documentChangeUri (InR (InR (InL x))) = x ^. oldUri
317 documentChangeUri (InR (InR (InR x))) = x ^. uri
318
319 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m)
320             => FromServerMessage -> m ()
321 updateState (FromServerMess SProgress req) = case req ^. params . value of
322   Begin _ ->
323     modify $ \s -> s { curProgressSessions = Set.insert (req ^. params . token) $ curProgressSessions s }
324   End _ ->
325     modify $ \s -> s { curProgressSessions = Set.delete (req ^. params . token) $ curProgressSessions s }
326   _ -> pure ()
327
328 -- Keep track of dynamic capability registration
329 updateState (FromServerMess SClientRegisterCapability req) = do
330   let List newRegs = (\sr@(SomeRegistration r) -> (r ^. LSP.id, sr)) <$> req ^. params . registrations
331   modify $ \s ->
332     s { curDynCaps = Map.union (Map.fromList newRegs) (curDynCaps s) }
333
334 updateState (FromServerMess SClientUnregisterCapability req) = do
335   let List unRegs = (^. LSP.id) <$> req ^. params . unregisterations
336   modify $ \s ->
337     let newCurDynCaps = foldr' Map.delete (curDynCaps s) unRegs
338     in s { curDynCaps = newCurDynCaps }
339
340 updateState (FromServerMess STextDocumentPublishDiagnostics n) = do
341   let List diags = n ^. params . diagnostics
342       doc = n ^. params . uri
343   modify $ \s ->
344     let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s)
345       in s { curDiagnostics = newDiags }
346
347 updateState (FromServerMess SWorkspaceApplyEdit r) = do
348
349   -- First, prefer the versioned documentChanges field
350   allChangeParams <- case r ^. params . edit . documentChanges of
351     Just (List cs) -> do
352       mapM_ (checkIfNeedsOpened . documentChangeUri) cs
353       return $ mapMaybe getParamsFromDocumentChange cs
354     -- Then fall back to the changes field
355     Nothing -> case r ^. params . edit . changes of
356       Just cs -> do
357         mapM_ checkIfNeedsOpened (HashMap.keys cs)
358         concat <$> mapM (uncurry getChangeParams) (HashMap.toList cs)
359       Nothing ->
360         error "WorkspaceEdit contains neither documentChanges nor changes!"
361
362   modifyM $ \s -> do
363     newVFS <- liftIO $ changeFromServerVFS (vfs s) r
364     return $ s { vfs = newVFS }
365
366   let groupedParams = groupBy (\a b -> a ^. textDocument == b ^. textDocument) allChangeParams
367       mergedParams = map mergeParams groupedParams
368
369   -- TODO: Don't do this when replaying a session
370   forM_ mergedParams (sendMessage . NotificationMessage "2.0" STextDocumentDidChange)
371
372   -- Update VFS to new document versions
373   let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
374       latestVersions = map ((^. textDocument) . last) sortedVersions
375       bumpedVersions = map (version . _Just +~ 1) latestVersions
376
377   forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
378     modify $ \s ->
379       let oldVFS = vfs s
380           update (VirtualFile oldV file_ver t) = VirtualFile (fromMaybe oldV v) (file_ver + 1) t
381           newVFS = updateVFS (Map.adjust update (toNormalizedUri uri)) oldVFS
382       in s { vfs = newVFS }
383
384   where checkIfNeedsOpened uri = do
385           oldVFS <- vfs <$> get
386           ctx <- ask
387
388           -- if its not open, open it
389           unless (toNormalizedUri uri `Map.member` vfsMap oldVFS) $ do
390             let fp = fromJust $ uriToFilePath uri
391             contents <- liftIO $ T.readFile fp
392             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
393                 msg = NotificationMessage "2.0" STextDocumentDidOpen (DidOpenTextDocumentParams item)
394             liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
395
396             modifyM $ \s -> do
397               let (newVFS,_) = openVFS (vfs s) msg
398               return $ s { vfs = newVFS }
399
400         getParamsFromTextDocumentEdit :: TextDocumentEdit -> DidChangeTextDocumentParams
401         getParamsFromTextDocumentEdit (TextDocumentEdit docId (List edits)) = 
402           let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
403             in DidChangeTextDocumentParams docId (List changeEvents)
404
405         getParamsFromDocumentChange :: DocumentChange -> Maybe DidChangeTextDocumentParams
406         getParamsFromDocumentChange (InL textDocumentEdit) = Just $ getParamsFromTextDocumentEdit textDocumentEdit
407         getParamsFromDocumentChange _ = Nothing
408
409
410         -- For a uri returns an infinite list of versions [n,n+1,n+2,...]
411         -- where n is the current version
412         textDocumentVersions uri = do
413           m <- vfsMap . vfs <$> get
414           let curVer = fromMaybe 0 $
415                 _lsp_version <$> m Map.!? (toNormalizedUri uri)
416           pure $ map (VersionedTextDocumentIdentifier uri . Just) [curVer + 1..]
417
418         textDocumentEdits uri edits = do
419           vers <- textDocumentVersions uri
420           pure $ map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip vers edits
421
422         getChangeParams uri (List edits) = do 
423           map <$> pure getParamsFromTextDocumentEdit <*> textDocumentEdits uri (reverse edits)
424
425         mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
426         mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
427                               in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
428 updateState _ = return ()
429
430 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
431 sendMessage msg = do
432   h <- serverIn <$> ask
433   logMsg LogClient msg
434   liftIO $ B.hPut h (addHeader $ encode msg)
435
436 -- | Execute a block f that will throw a 'Language.LSP.Test.Exception.Timeout' exception
437 -- after duration seconds. This will override the global timeout
438 -- for waiting for messages to arrive defined in 'SessionConfig'.
439 withTimeout :: Int -> Session a -> Session a
440 withTimeout duration f = do
441   chan <- asks messageChan
442   timeoutId <- getCurTimeoutId
443   modify $ \s -> s { overridingTimeout = True }
444   liftIO $ forkIO $ do
445     threadDelay (duration * 1000000)
446     writeChan chan (TimeoutMessage timeoutId)
447   res <- f
448   bumpTimeoutId timeoutId
449   modify $ \s -> s { overridingTimeout = False }
450   return res
451
452 data LogMsgType = LogServer | LogClient
453   deriving Eq
454
455 -- | Logs the message if the config specified it
456 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
457        => LogMsgType -> a -> m ()
458 logMsg t msg = do
459   shouldLog <- asks $ logMessages . config
460   shouldColor <- asks $ logColor . config
461   liftIO $ when shouldLog $ do
462     when shouldColor $ setSGR [SetColor Foreground Dull color]
463     putStrLn $ arrow ++ showPretty msg
464     when shouldColor $ setSGR [Reset]
465
466   where arrow
467           | t == LogServer  = "<-- "
468           | otherwise       = "--> "
469         color
470           | t == LogServer  = Magenta
471           | otherwise       = Cyan
472
473         showPretty = B.unpack . encodePretty