36007391c35965676be1a5f87a337e33947a3afb
[lsp-test.git] / src / Language / LSP / Test / Session.hs
1 {-# LANGUAGE CPP               #-}
2 {-# LANGUAGE GADTs             #-}
3 {-# LANGUAGE OverloadedStrings #-}
4 {-# LANGUAGE GeneralizedNewtypeDeriving #-}
5 {-# LANGUAGE FlexibleInstances #-}
6 {-# LANGUAGE MultiParamTypeClasses #-}
7 {-# LANGUAGE FlexibleContexts #-}
8 {-# LANGUAGE RankNTypes #-}
9 {-# LANGUAGE TypeInType #-}
10
11 module Language.LSP.Test.Session
12   ( Session(..)
13   , SessionConfig(..)
14   , defaultConfig
15   , SessionMessage(..)
16   , SessionContext(..)
17   , SessionState(..)
18   , runSession'
19   , get
20   , put
21   , modify
22   , modifyM
23   , ask
24   , asks
25   , sendMessage
26   , updateState
27   , withTimeout
28   , getCurTimeoutId
29   , bumpTimeoutId
30   , logMsg
31   , LogMsgType(..)
32   )
33
34 where
35
36 import Control.Applicative
37 import Control.Concurrent hiding (yield)
38 import Control.Exception
39 import Control.Lens hiding (List)
40 import Control.Monad
41 import Control.Monad.IO.Class
42 import Control.Monad.Except
43 #if __GLASGOW_HASKELL__ == 806
44 import Control.Monad.Fail
45 #endif
46 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
47 import qualified Control.Monad.Trans.Reader as Reader (ask)
48 import Control.Monad.Trans.State (StateT, runStateT)
49 import qualified Control.Monad.Trans.State as State
50 import qualified Data.ByteString.Lazy.Char8 as B
51 import Data.Aeson
52 import Data.Aeson.Encode.Pretty
53 import Data.Conduit as Conduit
54 import Data.Conduit.Parser as Parser
55 import Data.Default
56 import Data.Foldable
57 import Data.List
58 import qualified Data.Map as Map
59 import qualified Data.Text as T
60 import qualified Data.Text.IO as T
61 import qualified Data.HashMap.Strict as HashMap
62 import Data.Maybe
63 import Data.Function
64 import Language.LSP.Types.Capabilities
65 import Language.LSP.Types
66 import Language.LSP.Types.Lens
67 import qualified Language.LSP.Types.Lens as LSP
68 import Language.LSP.VFS
69 import Language.LSP.Test.Compat
70 import Language.LSP.Test.Decoding
71 import Language.LSP.Test.Exceptions
72 import System.Console.ANSI
73 import System.Directory
74 import System.IO
75 import System.Process (ProcessHandle())
76 #ifndef mingw32_HOST_OS
77 import System.Process (waitForProcess)
78 #endif
79 import System.Timeout
80
81 -- | A session representing one instance of launching and connecting to a server.
82 --
83 -- You can send and receive messages to the server within 'Session' via
84 -- 'Language.LSP.Test.message',
85 -- 'Language.LSP.Test.sendRequest' and
86 -- 'Language.LSP.Test.sendNotification'.
87
88 newtype Session a = Session (ConduitParser FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) a)
89   deriving (Functor, Applicative, Monad, MonadIO, Alternative)
90
91 #if __GLASGOW_HASKELL__ >= 806
92 instance MonadFail Session where
93   fail s = do
94     lastMsg <- fromJust . lastReceivedMessage <$> get
95     liftIO $ throw (UnexpectedMessage s lastMsg)
96 #endif
97
98 -- | Stuff you can configure for a 'Session'.
99 data SessionConfig = SessionConfig
100   { messageTimeout :: Int  -- ^ Maximum time to wait for a message in seconds, defaults to 60.
101   , logStdErr      :: Bool
102   -- ^ Redirect the server's stderr to this stdout, defaults to False.
103   -- Can be overriden with @LSP_TEST_LOG_STDERR@.
104   , logMessages    :: Bool
105   -- ^ Trace the messages sent and received to stdout, defaults to False.
106   -- Can be overriden with the environment variable @LSP_TEST_LOG_MESSAGES@.
107   , logColor       :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
108   , lspConfig      :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
109   , ignoreLogNotifications :: Bool
110   -- ^ Whether or not to ignore 'Language.LSP.Types.ShowMessageNotification' and
111   -- 'Language.LSP.Types.LogMessageNotification', defaults to False.
112   --
113   -- @since 0.9.0.0
114   , initialWorkspaceFolders :: Maybe [WorkspaceFolder]
115   -- ^ The initial workspace folders to send in the @initialize@ request.
116   -- Defaults to Nothing.
117   }
118
119 -- | The configuration used in 'Language.LSP.Test.runSession'.
120 defaultConfig :: SessionConfig
121 defaultConfig = SessionConfig 60 False False True Nothing False Nothing
122
123 instance Default SessionConfig where
124   def = defaultConfig
125
126 data SessionMessage = ServerMessage FromServerMessage
127                     | TimeoutMessage Int
128   deriving Show
129
130 data SessionContext = SessionContext
131   {
132     serverIn :: Handle
133   , rootDir :: FilePath
134   , messageChan :: Chan SessionMessage -- ^ Where all messages come through
135   -- Keep curTimeoutId in SessionContext, as its tied to messageChan
136   , curTimeoutId :: MVar Int -- ^ The current timeout we are waiting on
137   , requestMap :: MVar RequestMap
138   , initRsp :: MVar (ResponseMessage Initialize)
139   , config :: SessionConfig
140   , sessionCapabilities :: ClientCapabilities
141   }
142
143 class Monad m => HasReader r m where
144   ask :: m r
145   asks :: (r -> b) -> m b
146   asks f = f <$> ask
147
148 instance HasReader SessionContext Session where
149   ask  = Session (lift $ lift Reader.ask)
150
151 instance Monad m => HasReader r (ConduitM a b (StateT s (ReaderT r m))) where
152   ask = lift $ lift Reader.ask
153
154 getCurTimeoutId :: (HasReader SessionContext m, MonadIO m) => m Int
155 getCurTimeoutId = asks curTimeoutId >>= liftIO . readMVar
156
157 -- Pass this the timeoutid you *were* waiting on
158 bumpTimeoutId :: (HasReader SessionContext m, MonadIO m) => Int -> m ()
159 bumpTimeoutId prev = do
160   v <- asks curTimeoutId
161   -- when updating the curtimeoutid, account for the fact that something else
162   -- might have bumped the timeoutid in the meantime
163   liftIO $ modifyMVar_ v (\x -> pure (max x (prev + 1)))
164
165 data SessionState = SessionState
166   {
167     curReqId :: Int
168   , vfs :: VFS
169   , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
170   , overridingTimeout :: Bool
171   -- ^ The last received message from the server.
172   -- Used for providing exception information
173   , lastReceivedMessage :: Maybe FromServerMessage
174   , curDynCaps :: Map.Map T.Text SomeRegistration
175   -- ^ The capabilities that the server has dynamically registered with us so
176   -- far
177   }
178
179 class Monad m => HasState s m where
180   get :: m s
181
182   put :: s -> m ()
183
184   modify :: (s -> s) -> m ()
185   modify f = get >>= put . f
186
187   modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
188   modifyM f = get >>= f >>= put
189
190 instance HasState SessionState Session where
191   get = Session (lift State.get)
192   put = Session . lift . State.put
193
194 instance Monad m => HasState s (StateT s m) where
195   get = State.get
196   put = State.put
197
198 instance (Monad m, (HasState s m)) => HasState s (ConduitM a b m)
199  where
200   get = lift get
201   put = lift . put
202
203 instance (Monad m, (HasState s m)) => HasState s (ConduitParser a m)
204  where
205   get = lift get
206   put = lift . put
207
208 runSessionMonad :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
209 runSessionMonad context state (Session session) = runReaderT (runStateT conduit state) context
210   where
211     conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
212
213     handler (Unexpected "ConduitParser.empty") = do
214       lastMsg <- fromJust . lastReceivedMessage <$> get
215       name <- getParserName
216       liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
217
218     handler e = throw e
219
220     chanSource = do
221       msg <- liftIO $ readChan (messageChan context)
222       unless (ignoreLogNotifications (config context) && isLogNotification msg) $
223         yield msg
224       chanSource
225
226     isLogNotification (ServerMessage (FromServerMess SWindowShowMessage _)) = True
227     isLogNotification (ServerMessage (FromServerMess SWindowLogMessage _)) = True
228     isLogNotification _ = False
229
230     watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
231     watchdog = Conduit.awaitForever $ \msg -> do
232       curId <- getCurTimeoutId
233       case msg of
234         ServerMessage sMsg -> yield sMsg
235         TimeoutMessage tId -> when (curId == tId) $ lastReceivedMessage <$> get >>= throw . Timeout
236
237 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
238 -- It also does not automatically send initialize and exit messages.
239 runSession' :: Handle -- ^ Server in
240             -> Handle -- ^ Server out
241             -> Maybe ProcessHandle -- ^ Server process
242             -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
243             -> SessionConfig
244             -> ClientCapabilities
245             -> FilePath -- ^ Root directory
246             -> Session () -- ^ To exit the Server properly
247             -> Session a
248             -> IO a
249 runSession' serverIn serverOut mServerProc serverHandler config caps rootDir exitServer session = do
250   absRootDir <- canonicalizePath rootDir
251
252   hSetBuffering serverIn  NoBuffering
253   hSetBuffering serverOut NoBuffering
254   -- This is required to make sure that we don’t get any
255   -- newline conversion or weird encoding issues.
256   hSetBinaryMode serverIn True
257   hSetBinaryMode serverOut True
258
259   reqMap <- newMVar newRequestMap
260   messageChan <- newChan
261   timeoutIdVar <- newMVar 0
262   initRsp <- newEmptyMVar
263
264   mainThreadId <- myThreadId
265
266   let context = SessionContext serverIn absRootDir messageChan timeoutIdVar reqMap initRsp config caps
267       initState vfs = SessionState 0 vfs mempty False Nothing mempty
268       runSession' ses = initVFS $ \vfs -> runSessionMonad context (initState vfs) ses
269
270       errorHandler = throwTo mainThreadId :: SessionException -> IO ()
271       serverListenerLauncher =
272         forkIO $ catch (serverHandler serverOut context) errorHandler
273       msgTimeoutMs = messageTimeout config * 10^6
274       serverAndListenerFinalizer tid = do
275         let cleanup
276               | Just sp <- mServerProc = do
277                   -- Give the server some time to exit cleanly
278                   -- It makes the server hangs in windows so we have to avoid it
279 #ifndef mingw32_HOST_OS
280                   timeout msgTimeoutMs (waitForProcess sp)
281 #endif
282                   cleanupProcess (Just serverIn, Just serverOut, Nothing, sp)
283               | otherwise = pure ()
284         finally (timeout msgTimeoutMs (runSession' exitServer))
285                 -- Make sure to kill the listener first, before closing
286                 -- handles etc via cleanupProcess
287                 (killThread tid >> cleanup)
288
289   (result, _) <- bracket serverListenerLauncher
290                          serverAndListenerFinalizer
291                          (const $ initVFS $ \vfs -> runSessionMonad context (initState vfs) session)
292   return result
293
294 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
295 updateStateC = awaitForever $ \msg -> do
296   updateState msg
297   yield msg
298
299
300 -- extract Uri out from DocumentChange
301 documentChangeUri :: DocumentChange -> Uri
302 documentChangeUri (InL x) = x ^. textDocument . uri
303 documentChangeUri (InR (InL x)) = x ^. uri
304 documentChangeUri (InR (InR (InL x))) = x ^. oldUri
305 documentChangeUri (InR (InR (InR x))) = x ^. uri
306
307 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m)
308             => FromServerMessage -> m ()
309
310 -- Keep track of dynamic capability registration
311 updateState (FromServerMess SClientRegisterCapability req) = do
312   let List newRegs = (\sr@(SomeRegistration r) -> (r ^. LSP.id, sr)) <$> req ^. params . registrations
313   modify $ \s ->
314     s { curDynCaps = Map.union (Map.fromList newRegs) (curDynCaps s) }
315
316 updateState (FromServerMess SClientUnregisterCapability req) = do
317   let List unRegs = (^. LSP.id) <$> req ^. params . unregisterations
318   modify $ \s ->
319     let newCurDynCaps = foldr' Map.delete (curDynCaps s) unRegs
320     in s { curDynCaps = newCurDynCaps }
321
322 updateState (FromServerMess STextDocumentPublishDiagnostics n) = do
323   let List diags = n ^. params . diagnostics
324       doc = n ^. params . uri
325   modify $ \s ->
326     let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s)
327       in s { curDiagnostics = newDiags }
328
329 updateState (FromServerMess SWorkspaceApplyEdit r) = do
330
331   -- First, prefer the versioned documentChanges field
332   allChangeParams <- case r ^. params . edit . documentChanges of
333     Just (List cs) -> do
334       mapM_ (checkIfNeedsOpened . documentChangeUri) cs
335       return $ mapMaybe getParamsFromDocumentChange cs
336     -- Then fall back to the changes field
337     Nothing -> case r ^. params . edit . changes of
338       Just cs -> do
339         mapM_ checkIfNeedsOpened (HashMap.keys cs)
340         concat <$> mapM (uncurry getChangeParams) (HashMap.toList cs)
341       Nothing ->
342         error "WorkspaceEdit contains neither documentChanges nor changes!"
343
344   modifyM $ \s -> do
345     newVFS <- liftIO $ changeFromServerVFS (vfs s) r
346     return $ s { vfs = newVFS }
347
348   let groupedParams = groupBy (\a b -> a ^. textDocument == b ^. textDocument) allChangeParams
349       mergedParams = map mergeParams groupedParams
350
351   -- TODO: Don't do this when replaying a session
352   forM_ mergedParams (sendMessage . NotificationMessage "2.0" STextDocumentDidChange)
353
354   -- Update VFS to new document versions
355   let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
356       latestVersions = map ((^. textDocument) . last) sortedVersions
357       bumpedVersions = map (version . _Just +~ 1) latestVersions
358
359   forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
360     modify $ \s ->
361       let oldVFS = vfs s
362           update (VirtualFile oldV file_ver t) = VirtualFile (fromMaybe oldV v) (file_ver + 1) t
363           newVFS = updateVFS (Map.adjust update (toNormalizedUri uri)) oldVFS
364       in s { vfs = newVFS }
365
366   where checkIfNeedsOpened uri = do
367           oldVFS <- vfs <$> get
368           ctx <- ask
369
370           -- if its not open, open it
371           unless (toNormalizedUri uri `Map.member` vfsMap oldVFS) $ do
372             let fp = fromJust $ uriToFilePath uri
373             contents <- liftIO $ T.readFile fp
374             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
375                 msg = NotificationMessage "2.0" STextDocumentDidOpen (DidOpenTextDocumentParams item)
376             liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
377
378             modifyM $ \s -> do
379               let (newVFS,_) = openVFS (vfs s) msg
380               return $ s { vfs = newVFS }
381
382         getParamsFromTextDocumentEdit :: TextDocumentEdit -> DidChangeTextDocumentParams
383         getParamsFromTextDocumentEdit (TextDocumentEdit docId (List edits)) = 
384           let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
385             in DidChangeTextDocumentParams docId (List changeEvents)
386
387         getParamsFromDocumentChange :: DocumentChange -> Maybe DidChangeTextDocumentParams
388         getParamsFromDocumentChange (InL textDocumentEdit) = Just $ getParamsFromTextDocumentEdit textDocumentEdit
389         getParamsFromDocumentChange _ = Nothing
390
391
392         -- For a uri returns an infinite list of versions [n,n+1,n+2,...]
393         -- where n is the current version
394         textDocumentVersions uri = do
395           m <- vfsMap . vfs <$> get
396           let curVer = fromMaybe 0 $
397                 _lsp_version <$> m Map.!? (toNormalizedUri uri)
398           pure $ map (VersionedTextDocumentIdentifier uri . Just) [curVer + 1..]
399
400         textDocumentEdits uri edits = do
401           vers <- textDocumentVersions uri
402           pure $ map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip vers edits
403
404         getChangeParams uri (List edits) = do 
405           map <$> pure getParamsFromTextDocumentEdit <*> textDocumentEdits uri (reverse edits)
406
407         mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
408         mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
409                               in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
410 updateState _ = return ()
411
412 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
413 sendMessage msg = do
414   h <- serverIn <$> ask
415   logMsg LogClient msg
416   liftIO $ B.hPut h (addHeader $ encode msg)
417
418 -- | Execute a block f that will throw a 'Language.LSP.Test.Exception.Timeout' exception
419 -- after duration seconds. This will override the global timeout
420 -- for waiting for messages to arrive defined in 'SessionConfig'.
421 withTimeout :: Int -> Session a -> Session a
422 withTimeout duration f = do
423   chan <- asks messageChan
424   timeoutId <- getCurTimeoutId
425   modify $ \s -> s { overridingTimeout = True }
426   liftIO $ forkIO $ do
427     threadDelay (duration * 1000000)
428     writeChan chan (TimeoutMessage timeoutId)
429   res <- f
430   bumpTimeoutId timeoutId
431   modify $ \s -> s { overridingTimeout = False }
432   return res
433
434 data LogMsgType = LogServer | LogClient
435   deriving Eq
436
437 -- | Logs the message if the config specified it
438 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
439        => LogMsgType -> a -> m ()
440 logMsg t msg = do
441   shouldLog <- asks $ logMessages . config
442   shouldColor <- asks $ logColor . config
443   liftIO $ when shouldLog $ do
444     when shouldColor $ setSGR [SetColor Foreground Dull color]
445     putStrLn $ arrow ++ showPretty msg
446     when shouldColor $ setSGR [Reset]
447
448   where arrow
449           | t == LogServer  = "<-- "
450           | otherwise       = "--> "
451         color
452           | t == LogServer  = Magenta
453           | otherwise       = Cyan
454
455         showPretty = B.unpack . encodePretty