Abstracting documentChangeUri
[lsp-test.git] / src / Language / LSP / Test / Session.hs
1 {-# LANGUAGE CPP               #-}
2 {-# LANGUAGE GADTs             #-}
3 {-# LANGUAGE OverloadedStrings #-}
4 {-# LANGUAGE GeneralizedNewtypeDeriving #-}
5 {-# LANGUAGE FlexibleInstances #-}
6 {-# LANGUAGE MultiParamTypeClasses #-}
7 {-# LANGUAGE FlexibleContexts #-}
8 {-# LANGUAGE RankNTypes #-}
9 {-# LANGUAGE TypeInType #-}
10
11 module Language.LSP.Test.Session
12   ( Session(..)
13   , SessionConfig(..)
14   , defaultConfig
15   , SessionMessage(..)
16   , SessionContext(..)
17   , SessionState(..)
18   , runSession'
19   , get
20   , put
21   , modify
22   , modifyM
23   , ask
24   , asks
25   , sendMessage
26   , updateState
27   , withTimeout
28   , getCurTimeoutId
29   , bumpTimeoutId
30   , logMsg
31   , LogMsgType(..)
32   , documentChangeUri
33   )
34
35 where
36
37 import Control.Applicative
38 import Control.Concurrent hiding (yield)
39 import Control.Exception
40 import Control.Lens hiding (List)
41 import Control.Monad
42 import Control.Monad.IO.Class
43 import Control.Monad.Except
44 #if __GLASGOW_HASKELL__ == 806
45 import Control.Monad.Fail
46 #endif
47 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
48 import qualified Control.Monad.Trans.Reader as Reader (ask)
49 import Control.Monad.Trans.State (StateT, runStateT)
50 import qualified Control.Monad.Trans.State as State
51 import qualified Data.ByteString.Lazy.Char8 as B
52 import Data.Aeson
53 import Data.Aeson.Encode.Pretty
54 import Data.Conduit as Conduit
55 import Data.Conduit.Parser as Parser
56 import Data.Default
57 import Data.Foldable
58 import Data.List
59 import qualified Data.Map as Map
60 import qualified Data.Text as T
61 import qualified Data.Text.IO as T
62 import qualified Data.HashMap.Strict as HashMap
63 import Data.Maybe
64 import Data.Function
65 import Language.LSP.Types.Capabilities
66 import Language.LSP.Types
67 import Language.LSP.Types.Lens
68 import qualified Language.LSP.Types.Lens as LSP
69 import Language.LSP.VFS
70 import Language.LSP.Test.Compat
71 import Language.LSP.Test.Decoding
72 import Language.LSP.Test.Exceptions
73 import System.Console.ANSI
74 import System.Directory
75 import System.IO
76 import System.Process (ProcessHandle())
77 #ifndef mingw32_HOST_OS
78 import System.Process (waitForProcess)
79 #endif
80 import System.Timeout
81
82 -- | A session representing one instance of launching and connecting to a server.
83 --
84 -- You can send and receive messages to the server within 'Session' via
85 -- 'Language.LSP.Test.message',
86 -- 'Language.LSP.Test.sendRequest' and
87 -- 'Language.LSP.Test.sendNotification'.
88
89 newtype Session a = Session (ConduitParser FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) a)
90   deriving (Functor, Applicative, Monad, MonadIO, Alternative)
91
92 #if __GLASGOW_HASKELL__ >= 806
93 instance MonadFail Session where
94   fail s = do
95     lastMsg <- fromJust . lastReceivedMessage <$> get
96     liftIO $ throw (UnexpectedMessage s lastMsg)
97 #endif
98
99 -- | Stuff you can configure for a 'Session'.
100 data SessionConfig = SessionConfig
101   { messageTimeout :: Int  -- ^ Maximum time to wait for a message in seconds, defaults to 60.
102   , logStdErr      :: Bool
103   -- ^ Redirect the server's stderr to this stdout, defaults to False.
104   -- Can be overriden with @LSP_TEST_LOG_STDERR@.
105   , logMessages    :: Bool
106   -- ^ Trace the messages sent and received to stdout, defaults to False.
107   -- Can be overriden with the environment variable @LSP_TEST_LOG_MESSAGES@.
108   , logColor       :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
109   , lspConfig      :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
110   , ignoreLogNotifications :: Bool
111   -- ^ Whether or not to ignore 'Language.LSP.Types.ShowMessageNotification' and
112   -- 'Language.LSP.Types.LogMessageNotification', defaults to False.
113   --
114   -- @since 0.9.0.0
115   , initialWorkspaceFolders :: Maybe [WorkspaceFolder]
116   -- ^ The initial workspace folders to send in the @initialize@ request.
117   -- Defaults to Nothing.
118   }
119
120 -- | The configuration used in 'Language.LSP.Test.runSession'.
121 defaultConfig :: SessionConfig
122 defaultConfig = SessionConfig 60 False False True Nothing False Nothing
123
124 instance Default SessionConfig where
125   def = defaultConfig
126
127 data SessionMessage = ServerMessage FromServerMessage
128                     | TimeoutMessage Int
129   deriving Show
130
131 data SessionContext = SessionContext
132   {
133     serverIn :: Handle
134   , rootDir :: FilePath
135   , messageChan :: Chan SessionMessage -- ^ Where all messages come through
136   -- Keep curTimeoutId in SessionContext, as its tied to messageChan
137   , curTimeoutId :: MVar Int -- ^ The current timeout we are waiting on
138   , requestMap :: MVar RequestMap
139   , initRsp :: MVar (ResponseMessage Initialize)
140   , config :: SessionConfig
141   , sessionCapabilities :: ClientCapabilities
142   }
143
144 class Monad m => HasReader r m where
145   ask :: m r
146   asks :: (r -> b) -> m b
147   asks f = f <$> ask
148
149 instance HasReader SessionContext Session where
150   ask  = Session (lift $ lift Reader.ask)
151
152 instance Monad m => HasReader r (ConduitM a b (StateT s (ReaderT r m))) where
153   ask = lift $ lift Reader.ask
154
155 getCurTimeoutId :: (HasReader SessionContext m, MonadIO m) => m Int
156 getCurTimeoutId = asks curTimeoutId >>= liftIO . readMVar
157
158 -- Pass this the timeoutid you *were* waiting on
159 bumpTimeoutId :: (HasReader SessionContext m, MonadIO m) => Int -> m ()
160 bumpTimeoutId prev = do
161   v <- asks curTimeoutId
162   -- when updating the curtimeoutid, account for the fact that something else
163   -- might have bumped the timeoutid in the meantime
164   liftIO $ modifyMVar_ v (\x -> pure (max x (prev + 1)))
165
166 data SessionState = SessionState
167   {
168     curReqId :: Int
169   , vfs :: VFS
170   , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
171   , overridingTimeout :: Bool
172   -- ^ The last received message from the server.
173   -- Used for providing exception information
174   , lastReceivedMessage :: Maybe FromServerMessage
175   , curDynCaps :: Map.Map T.Text SomeRegistration
176   -- ^ The capabilities that the server has dynamically registered with us so
177   -- far
178   }
179
180 class Monad m => HasState s m where
181   get :: m s
182
183   put :: s -> m ()
184
185   modify :: (s -> s) -> m ()
186   modify f = get >>= put . f
187
188   modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
189   modifyM f = get >>= f >>= put
190
191 instance HasState SessionState Session where
192   get = Session (lift State.get)
193   put = Session . lift . State.put
194
195 instance Monad m => HasState s (StateT s m) where
196   get = State.get
197   put = State.put
198
199 instance (Monad m, (HasState s m)) => HasState s (ConduitM a b m)
200  where
201   get = lift get
202   put = lift . put
203
204 instance (Monad m, (HasState s m)) => HasState s (ConduitParser a m)
205  where
206   get = lift get
207   put = lift . put
208
209 runSessionMonad :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
210 runSessionMonad context state (Session session) = runReaderT (runStateT conduit state) context
211   where
212     conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
213
214     handler (Unexpected "ConduitParser.empty") = do
215       lastMsg <- fromJust . lastReceivedMessage <$> get
216       name <- getParserName
217       liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
218
219     handler e = throw e
220
221     chanSource = do
222       msg <- liftIO $ readChan (messageChan context)
223       unless (ignoreLogNotifications (config context) && isLogNotification msg) $
224         yield msg
225       chanSource
226
227     isLogNotification (ServerMessage (FromServerMess SWindowShowMessage _)) = True
228     isLogNotification (ServerMessage (FromServerMess SWindowLogMessage _)) = True
229     isLogNotification _ = False
230
231     watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
232     watchdog = Conduit.awaitForever $ \msg -> do
233       curId <- getCurTimeoutId
234       case msg of
235         ServerMessage sMsg -> yield sMsg
236         TimeoutMessage tId -> when (curId == tId) $ lastReceivedMessage <$> get >>= throw . Timeout
237
238 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
239 -- It also does not automatically send initialize and exit messages.
240 runSession' :: Handle -- ^ Server in
241             -> Handle -- ^ Server out
242             -> Maybe ProcessHandle -- ^ Server process
243             -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
244             -> SessionConfig
245             -> ClientCapabilities
246             -> FilePath -- ^ Root directory
247             -> Session () -- ^ To exit the Server properly
248             -> Session a
249             -> IO a
250 runSession' serverIn serverOut mServerProc serverHandler config caps rootDir exitServer session = do
251   absRootDir <- canonicalizePath rootDir
252
253   hSetBuffering serverIn  NoBuffering
254   hSetBuffering serverOut NoBuffering
255   -- This is required to make sure that we don’t get any
256   -- newline conversion or weird encoding issues.
257   hSetBinaryMode serverIn True
258   hSetBinaryMode serverOut True
259
260   reqMap <- newMVar newRequestMap
261   messageChan <- newChan
262   timeoutIdVar <- newMVar 0
263   initRsp <- newEmptyMVar
264
265   mainThreadId <- myThreadId
266
267   let context = SessionContext serverIn absRootDir messageChan timeoutIdVar reqMap initRsp config caps
268       initState vfs = SessionState 0 vfs mempty False Nothing mempty
269       runSession' ses = initVFS $ \vfs -> runSessionMonad context (initState vfs) ses
270
271       errorHandler = throwTo mainThreadId :: SessionException -> IO ()
272       serverListenerLauncher =
273         forkIO $ catch (serverHandler serverOut context) errorHandler
274       msgTimeoutMs = messageTimeout config * 10^6
275       serverAndListenerFinalizer tid = do
276         let cleanup
277               | Just sp <- mServerProc = do
278                   -- Give the server some time to exit cleanly
279                   -- It makes the server hangs in windows so we have to avoid it
280 #ifndef mingw32_HOST_OS
281                   timeout msgTimeoutMs (waitForProcess sp)
282 #endif
283                   cleanupProcess (Just serverIn, Just serverOut, Nothing, sp)
284               | otherwise = pure ()
285         finally (timeout msgTimeoutMs (runSession' exitServer))
286                 -- Make sure to kill the listener first, before closing
287                 -- handles etc via cleanupProcess
288                 (killThread tid >> cleanup)
289
290   (result, _) <- bracket serverListenerLauncher
291                          serverAndListenerFinalizer
292                          (const $ initVFS $ \vfs -> runSessionMonad context (initState vfs) session)
293   return result
294
295 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
296 updateStateC = awaitForever $ \msg -> do
297   updateState msg
298   yield msg
299
300 -- extract Uri out from DocumentChange
301 -- didn't put this in `lsp-types` because TH was getting in the way
302 documentChangeUri :: DocumentChange -> Uri
303 documentChangeUri (InL x) = x ^. textDocument . uri
304 documentChangeUri (InR (InL x)) = x ^. uri
305 documentChangeUri (InR (InR (InL x))) = x ^. oldUri
306 documentChangeUri (InR (InR (InR x))) = x ^. uri
307
308 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m)
309             => FromServerMessage -> m ()
310
311 -- Keep track of dynamic capability registration
312 updateState (FromServerMess SClientRegisterCapability req) = do
313   let List newRegs = (\sr@(SomeRegistration r) -> (r ^. LSP.id, sr)) <$> req ^. params . registrations
314   modify $ \s ->
315     s { curDynCaps = Map.union (Map.fromList newRegs) (curDynCaps s) }
316
317 updateState (FromServerMess SClientUnregisterCapability req) = do
318   let List unRegs = (^. LSP.id) <$> req ^. params . unregisterations
319   modify $ \s ->
320     let newCurDynCaps = foldr' Map.delete (curDynCaps s) unRegs
321     in s { curDynCaps = newCurDynCaps }
322
323 updateState (FromServerMess STextDocumentPublishDiagnostics n) = do
324   let List diags = n ^. params . diagnostics
325       doc = n ^. params . uri
326   modify $ \s ->
327     let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s)
328       in s { curDiagnostics = newDiags }
329
330 updateState (FromServerMess SWorkspaceApplyEdit r) = do
331
332   -- First, prefer the versioned documentChanges field
333   allChangeParams <- case r ^. params . edit . documentChanges of
334     Just (List cs) -> do
335       mapM_ (checkIfNeedsOpened . documentChangeUri) cs
336       return $ mapMaybe getParamsFromDocumentChange cs
337     -- Then fall back to the changes field
338     Nothing -> case r ^. params . edit . changes of
339       Just cs -> do
340         mapM_ checkIfNeedsOpened (HashMap.keys cs)
341         concat <$> mapM (uncurry getChangeParams) (HashMap.toList cs)
342       Nothing ->
343         error "WorkspaceEdit contains neither documentChanges nor changes!"
344
345   modifyM $ \s -> do
346     newVFS <- liftIO $ changeFromServerVFS (vfs s) r
347     return $ s { vfs = newVFS }
348
349   let groupedParams = groupBy (\a b -> a ^. textDocument == b ^. textDocument) allChangeParams
350       mergedParams = map mergeParams groupedParams
351
352   -- TODO: Don't do this when replaying a session
353   forM_ mergedParams (sendMessage . NotificationMessage "2.0" STextDocumentDidChange)
354
355   -- Update VFS to new document versions
356   let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
357       latestVersions = map ((^. textDocument) . last) sortedVersions
358       bumpedVersions = map (version . _Just +~ 1) latestVersions
359
360   forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
361     modify $ \s ->
362       let oldVFS = vfs s
363           update (VirtualFile oldV file_ver t) = VirtualFile (fromMaybe oldV v) (file_ver + 1) t
364           newVFS = updateVFS (Map.adjust update (toNormalizedUri uri)) oldVFS
365       in s { vfs = newVFS }
366
367   where checkIfNeedsOpened uri = do
368           oldVFS <- vfs <$> get
369           ctx <- ask
370
371           -- if its not open, open it
372           unless (toNormalizedUri uri `Map.member` vfsMap oldVFS) $ do
373             let fp = fromJust $ uriToFilePath uri
374             contents <- liftIO $ T.readFile fp
375             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
376                 msg = NotificationMessage "2.0" STextDocumentDidOpen (DidOpenTextDocumentParams item)
377             liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
378
379             modifyM $ \s -> do
380               let (newVFS,_) = openVFS (vfs s) msg
381               return $ s { vfs = newVFS }
382
383         getParamsFromTextDocumentEdit :: TextDocumentEdit -> DidChangeTextDocumentParams
384         getParamsFromTextDocumentEdit (TextDocumentEdit docId (List edits)) = 
385           let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
386             in DidChangeTextDocumentParams docId (List changeEvents)
387
388         getParamsFromDocumentChange :: DocumentChange -> Maybe DidChangeTextDocumentParams
389         getParamsFromDocumentChange (InL textDocumentEdit) = Just $ getParamsFromTextDocumentEdit textDocumentEdit
390         getParamsFromDocumentChange _ = Nothing
391
392
393         -- For a uri returns an infinite list of versions [n,n+1,n+2,...]
394         -- where n is the current version
395         textDocumentVersions uri = do
396           m <- vfsMap . vfs <$> get
397           let curVer = fromMaybe 0 $
398                 _lsp_version <$> m Map.!? (toNormalizedUri uri)
399           pure $ map (VersionedTextDocumentIdentifier uri . Just) [curVer + 1..]
400
401         textDocumentEdits uri edits = do
402           vers <- textDocumentVersions uri
403           pure $ map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip vers edits
404
405         getChangeParams uri (List edits) = do 
406           map <$> pure getParamsFromTextDocumentEdit <*> textDocumentEdits uri (reverse edits)
407
408         mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
409         mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
410                               in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
411 updateState _ = return ()
412
413 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
414 sendMessage msg = do
415   h <- serverIn <$> ask
416   logMsg LogClient msg
417   liftIO $ B.hPut h (addHeader $ encode msg)
418
419 -- | Execute a block f that will throw a 'Language.LSP.Test.Exception.Timeout' exception
420 -- after duration seconds. This will override the global timeout
421 -- for waiting for messages to arrive defined in 'SessionConfig'.
422 withTimeout :: Int -> Session a -> Session a
423 withTimeout duration f = do
424   chan <- asks messageChan
425   timeoutId <- getCurTimeoutId
426   modify $ \s -> s { overridingTimeout = True }
427   liftIO $ forkIO $ do
428     threadDelay (duration * 1000000)
429     writeChan chan (TimeoutMessage timeoutId)
430   res <- f
431   bumpTimeoutId timeoutId
432   modify $ \s -> s { overridingTimeout = False }
433   return res
434
435 data LogMsgType = LogServer | LogClient
436   deriving Eq
437
438 -- | Logs the message if the config specified it
439 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
440        => LogMsgType -> a -> m ()
441 logMsg t msg = do
442   shouldLog <- asks $ logMessages . config
443   shouldColor <- asks $ logColor . config
444   liftIO $ when shouldLog $ do
445     when shouldColor $ setSGR [SetColor Foreground Dull color]
446     putStrLn $ arrow ++ showPretty msg
447     when shouldColor $ setSGR [Reset]
448
449   where arrow
450           | t == LogServer  = "<-- "
451           | otherwise       = "--> "
452         color
453           | t == LogServer  = Magenta
454           | otherwise       = Cyan
455
456         showPretty = B.unpack . encodePretty