a6474bb2ca1b0b9b9bd0964960c5a4670c48cadb
[lsp-test.git] / src / Language / Haskell / LSP / Test / Session.hs
1 {-# LANGUAGE CPP               #-}
2 {-# LANGUAGE GADTs             #-}
3 {-# LANGUAGE OverloadedStrings #-}
4 {-# LANGUAGE GeneralizedNewtypeDeriving #-}
5 {-# LANGUAGE FlexibleInstances #-}
6 {-# LANGUAGE MultiParamTypeClasses #-}
7 {-# LANGUAGE FlexibleContexts #-}
8 {-# LANGUAGE RankNTypes #-}
9
10 module Language.Haskell.LSP.Test.Session
11   ( Session(..)
12   , SessionConfig(..)
13   , defaultConfig
14   , SessionMessage(..)
15   , SessionContext(..)
16   , SessionState(..)
17   , runSession'
18   , get
19   , put
20   , modify
21   , modifyM
22   , ask
23   , asks
24   , sendMessage
25   , updateState
26   , withTimeout
27   , getCurTimeoutId
28   , bumpTimeoutId
29   , logMsg
30   , LogMsgType(..)
31   )
32
33 where
34
35 import Control.Applicative
36 import Control.Concurrent hiding (yield)
37 import Control.Exception
38 import Control.Lens hiding (List)
39 import Control.Monad
40 import Control.Monad.IO.Class
41 import Control.Monad.Except
42 #if __GLASGOW_HASKELL__ == 806
43 import Control.Monad.Fail
44 #endif
45 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
46 import qualified Control.Monad.Trans.Reader as Reader (ask)
47 import Control.Monad.Trans.State (StateT, runStateT)
48 import qualified Control.Monad.Trans.State as State
49 import qualified Data.ByteString.Lazy.Char8 as B
50 import Data.Aeson
51 import Data.Aeson.Encode.Pretty
52 import Data.Conduit as Conduit
53 import Data.Conduit.Parser as Parser
54 import Data.Default
55 import Data.Foldable
56 import Data.List
57 import qualified Data.Map as Map
58 import qualified Data.Text as T
59 import qualified Data.Text.IO as T
60 import qualified Data.HashMap.Strict as HashMap
61 import Data.Maybe
62 import Data.Function
63 import Language.Haskell.LSP.Types.Capabilities
64 import Language.Haskell.LSP.Types
65 import Language.Haskell.LSP.Types.Lens
66 import qualified Language.Haskell.LSP.Types.Lens as LSP
67 import Language.Haskell.LSP.VFS
68 import Language.Haskell.LSP.Test.Compat
69 import Language.Haskell.LSP.Test.Decoding
70 import Language.Haskell.LSP.Test.Exceptions
71 import System.Console.ANSI
72 import System.Directory
73 import System.IO
74 import System.Process (ProcessHandle())
75 #ifndef mingw32_HOST_OS
76 import System.Process (waitForProcess)
77 #endif
78 import System.Timeout
79
80 -- | A session representing one instance of launching and connecting to a server.
81 --
82 -- You can send and receive messages to the server within 'Session' via
83 -- 'Language.Haskell.LSP.Test.message',
84 -- 'Language.Haskell.LSP.Test.sendRequest' and
85 -- 'Language.Haskell.LSP.Test.sendNotification'.
86
87 newtype Session a = Session (ConduitParser FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) a)
88   deriving (Functor, Applicative, Monad, MonadIO, Alternative)
89
90 #if __GLASGOW_HASKELL__ >= 806
91 instance MonadFail Session where
92   fail s = do
93     lastMsg <- fromJust . lastReceivedMessage <$> get
94     liftIO $ throw (UnexpectedMessage s lastMsg)
95 #endif
96
97 -- | Stuff you can configure for a 'Session'.
98 data SessionConfig = SessionConfig
99   { messageTimeout :: Int  -- ^ Maximum time to wait for a message in seconds, defaults to 60.
100   , logStdErr      :: Bool
101   -- ^ Redirect the server's stderr to this stdout, defaults to False.
102   -- Can be overriden with @LSP_TEST_LOG_STDERR@.
103   , logMessages    :: Bool
104   -- ^ Trace the messages sent and received to stdout, defaults to False.
105   -- Can be overriden with the environment variable @LSP_TEST_LOG_MESSAGES@.
106   , logColor       :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
107   , lspConfig      :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
108   , ignoreLogNotifications :: Bool
109   -- ^ Whether or not to ignore 'Language.Haskell.LSP.Types.ShowMessageNotification' and
110   -- 'Language.Haskell.LSP.Types.LogMessageNotification', defaults to False.
111   --
112   -- @since 0.9.0.0
113   , initialWorkspaceFolders :: Maybe [WorkspaceFolder]
114   -- ^ The initial workspace folders to send in the @initialize@ request.
115   -- Defaults to Nothing.
116   }
117
118 -- | The configuration used in 'Language.Haskell.LSP.Test.runSession'.
119 defaultConfig :: SessionConfig
120 defaultConfig = SessionConfig 60 False False True Nothing False Nothing
121
122 instance Default SessionConfig where
123   def = defaultConfig
124
125 data SessionMessage = ServerMessage FromServerMessage
126                     | TimeoutMessage Int
127   deriving Show
128
129 data SessionContext = SessionContext
130   {
131     serverIn :: Handle
132   , rootDir :: FilePath
133   , messageChan :: Chan SessionMessage -- ^ Where all messages come through
134   -- Keep curTimeoutId in SessionContext, as its tied to messageChan
135   , curTimeoutId :: MVar Int -- ^ The current timeout we are waiting on
136   , requestMap :: MVar RequestMap
137   , initRsp :: MVar InitializeResponse
138   , config :: SessionConfig
139   , sessionCapabilities :: ClientCapabilities
140   }
141
142 class Monad m => HasReader r m where
143   ask :: m r
144   asks :: (r -> b) -> m b
145   asks f = f <$> ask
146
147 instance HasReader SessionContext Session where
148   ask  = Session (lift $ lift Reader.ask)
149
150 instance Monad m => HasReader r (ConduitM a b (StateT s (ReaderT r m))) where
151   ask = lift $ lift Reader.ask
152
153 getCurTimeoutId :: (HasReader SessionContext m, MonadIO m) => m Int
154 getCurTimeoutId = asks curTimeoutId >>= liftIO . readMVar
155
156 -- Pass this the timeoutid you *were* waiting on
157 bumpTimeoutId :: (HasReader SessionContext m, MonadIO m) => Int -> m ()
158 bumpTimeoutId prev = do
159   v <- asks curTimeoutId
160   -- when updating the curtimeoutid, account for the fact that something else
161   -- might have bumped the timeoutid in the meantime
162   liftIO $ modifyMVar_ v (\x -> pure (max x (prev + 1)))
163
164 data SessionState = SessionState
165   {
166     curReqId :: Int
167   , vfs :: VFS
168   , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
169   , overridingTimeout :: Bool
170   -- ^ The last received message from the server.
171   -- Used for providing exception information
172   , lastReceivedMessage :: Maybe FromServerMessage
173   , curDynCaps :: Map.Map T.Text SomeRegistration
174   -- ^ The capabilities that the server has dynamically registered with us so
175   -- far
176   }
177
178 class Monad m => HasState s m where
179   get :: m s
180
181   put :: s -> m ()
182
183   modify :: (s -> s) -> m ()
184   modify f = get >>= put . f
185
186   modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
187   modifyM f = get >>= f >>= put
188
189 instance HasState SessionState Session where
190   get = Session (lift State.get)
191   put = Session . lift . State.put
192
193 instance Monad m => HasState s (StateT s m) where
194   get = State.get
195   put = State.put
196
197 instance (Monad m, (HasState s m)) => HasState s (ConduitM a b m)
198  where
199   get = lift get
200   put = lift . put
201
202 instance (Monad m, (HasState s m)) => HasState s (ConduitParser a m)
203  where
204   get = lift get
205   put = lift . put
206
207 runSessionMonad :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
208 runSessionMonad context state (Session session) = runReaderT (runStateT conduit state) context
209   where
210     conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
211
212     handler (Unexpected "ConduitParser.empty") = do
213       lastMsg <- fromJust . lastReceivedMessage <$> get
214       name <- getParserName
215       liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
216
217     handler e = throw e
218
219     chanSource = do
220       msg <- liftIO $ readChan (messageChan context)
221       unless (ignoreLogNotifications (config context) && isLogNotification msg) $
222         yield msg
223       chanSource
224
225     isLogNotification (ServerMessage (FromServerMess SWindowShowMessage _)) = True
226     isLogNotification (ServerMessage (FromServerMess SWindowLogMessage _)) = True
227     isLogNotification _ = False
228
229     watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
230     watchdog = Conduit.awaitForever $ \msg -> do
231       curId <- getCurTimeoutId
232       case msg of
233         ServerMessage sMsg -> yield sMsg
234         TimeoutMessage tId -> when (curId == tId) $ lastReceivedMessage <$> get >>= throw . Timeout
235
236 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
237 -- It also does not automatically send initialize and exit messages.
238 runSession' :: Handle -- ^ Server in
239             -> Handle -- ^ Server out
240             -> Maybe ProcessHandle -- ^ Server process
241             -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
242             -> SessionConfig
243             -> ClientCapabilities
244             -> FilePath -- ^ Root directory
245             -> Session () -- ^ To exit the Server properly
246             -> Session a
247             -> IO a
248 runSession' serverIn serverOut mServerProc serverHandler config caps rootDir exitServer session = do
249   absRootDir <- canonicalizePath rootDir
250
251   hSetBuffering serverIn  NoBuffering
252   hSetBuffering serverOut NoBuffering
253   -- This is required to make sure that we don’t get any
254   -- newline conversion or weird encoding issues.
255   hSetBinaryMode serverIn True
256   hSetBinaryMode serverOut True
257
258   reqMap <- newMVar newRequestMap
259   messageChan <- newChan
260   timeoutIdVar <- newMVar 0
261   initRsp <- newEmptyMVar
262
263   mainThreadId <- myThreadId
264
265   let context = SessionContext serverIn absRootDir messageChan timeoutIdVar reqMap initRsp config caps
266       initState vfs = SessionState 0 vfs mempty False Nothing mempty
267       runSession' ses = initVFS $ \vfs -> runSessionMonad context (initState vfs) ses
268
269       errorHandler = throwTo mainThreadId :: SessionException -> IO ()
270       serverListenerLauncher =
271         forkIO $ catch (serverHandler serverOut context) errorHandler
272       server = (Just serverIn, Just serverOut, Nothing, serverProc)
273       msgTimeoutMs = messageTimeout config * 10^6
274       serverAndListenerFinalizer tid = do
275         let cleanup
276               | Just sp <- mServerProc = cleanupProcess (Just serverIn, Just serverOut, Nothing, sp)
277               | otherwise = pure ()
278         finally (timeout msgTimeoutMs (runSession' exitServer)) $ do
279                 -- Make sure to kill the listener first, before closing
280                 -- handles etc via cleanupProcess
281                 killThread tid
282                 -- Give the server some time to exit cleanly
283 #ifndef mingw32_HOST_OS
284                 timeout msgTimeoutMs (waitForProcess serverProc)
285 #endif
286                 cleanup
287
288   (result, _) <- bracket serverListenerLauncher
289                          serverAndListenerFinalizer
290                          (const $ initVFS $ \vfs -> runSessionMonad context (initState vfs) session)
291   return result
292
293 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
294 updateStateC = awaitForever $ \msg -> do
295   updateState msg
296   yield msg
297
298 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m)
299             => FromServerMessage -> m ()
300
301 -- Keep track of dynamic capability registration
302 updateState (FromServerMess SClientRegisterCapability req) = do
303   let List newRegs = (\sr@(SomeRegistration r) -> (r ^. LSP.id, sr)) <$> req ^. params . registrations
304   modify $ \s ->
305     s { curDynCaps = Map.union (Map.fromList newRegs) (curDynCaps s) }
306
307 updateState (FromServerMess SClientUnregisterCapability req) = do
308   let List unRegs = (^. LSP.id) <$> req ^. params . unregisterations
309   modify $ \s ->
310     let newCurDynCaps = foldr' Map.delete (curDynCaps s) unRegs
311     in s { curDynCaps = newCurDynCaps }
312
313 updateState (FromServerMess STextDocumentPublishDiagnostics n) = do
314   let List diags = n ^. params . diagnostics
315       doc = n ^. params . uri
316   modify $ \s ->
317     let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s)
318       in s { curDiagnostics = newDiags }
319
320 updateState (FromServerMess SWorkspaceApplyEdit r) = do
321
322   -- First, prefer the versioned documentChanges field
323   allChangeParams <- case r ^. params . edit . documentChanges of
324     Just (List cs) -> do
325       mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
326       return $ map getParams cs
327     -- Then fall back to the changes field
328     Nothing -> case r ^. params . edit . changes of
329       Just cs -> do
330         mapM_ checkIfNeedsOpened (HashMap.keys cs)
331         concat <$> mapM (uncurry getChangeParams) (HashMap.toList cs)
332       Nothing ->
333         error "WorkspaceEdit contains neither documentChanges nor changes!"
334
335   modifyM $ \s -> do
336     newVFS <- liftIO $ changeFromServerVFS (vfs s) r
337     return $ s { vfs = newVFS }
338
339   let groupedParams = groupBy (\a b -> a ^. textDocument == b ^. textDocument) allChangeParams
340       mergedParams = map mergeParams groupedParams
341
342   -- TODO: Don't do this when replaying a session
343   forM_ mergedParams (sendMessage . NotificationMessage "2.0" STextDocumentDidChange)
344
345   -- Update VFS to new document versions
346   let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
347       latestVersions = map ((^. textDocument) . last) sortedVersions
348       bumpedVersions = map (version . _Just +~ 1) latestVersions
349
350   forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
351     modify $ \s ->
352       let oldVFS = vfs s
353           update (VirtualFile oldV file_ver t) = VirtualFile (fromMaybe oldV v) (file_ver + 1) t
354           newVFS = updateVFS (Map.adjust update (toNormalizedUri uri)) oldVFS
355       in s { vfs = newVFS }
356
357   where checkIfNeedsOpened uri = do
358           oldVFS <- vfs <$> get
359           ctx <- ask
360
361           -- if its not open, open it
362           unless (toNormalizedUri uri `Map.member` vfsMap oldVFS) $ do
363             let fp = fromJust $ uriToFilePath uri
364             contents <- liftIO $ T.readFile fp
365             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
366                 msg = NotificationMessage "2.0" STextDocumentDidOpen (DidOpenTextDocumentParams item)
367             liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
368
369             modifyM $ \s -> do
370               let (newVFS,_) = openVFS (vfs s) msg
371               return $ s { vfs = newVFS }
372
373         getParams (TextDocumentEdit docId (List edits)) =
374           let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
375             in DidChangeTextDocumentParams docId (List changeEvents)
376
377         -- For a uri returns an infinite list of versions [n,n+1,n+2,...]
378         -- where n is the current version
379         textDocumentVersions uri = do
380           m <- vfsMap . vfs <$> get
381           let curVer = fromMaybe 0 $
382                 _lsp_version <$> m Map.!? (toNormalizedUri uri)
383           pure $ map (VersionedTextDocumentIdentifier uri . Just) [curVer + 1..]
384
385         textDocumentEdits uri edits = do
386           vers <- textDocumentVersions uri
387           pure $ map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip vers edits
388
389         getChangeParams uri (List edits) =
390           map <$> pure getParams <*> textDocumentEdits uri (reverse edits)
391
392         mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
393         mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
394                               in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
395 updateState _ = return ()
396
397 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
398 sendMessage msg = do
399   h <- serverIn <$> ask
400   logMsg LogClient msg
401   liftIO $ B.hPut h (addHeader $ encode msg)
402
403 -- | Execute a block f that will throw a 'Language.Haskell.LSP.Test.Exception.Timeout' exception
404 -- after duration seconds. This will override the global timeout
405 -- for waiting for messages to arrive defined in 'SessionConfig'.
406 withTimeout :: Int -> Session a -> Session a
407 withTimeout duration f = do
408   chan <- asks messageChan
409   timeoutId <- getCurTimeoutId
410   modify $ \s -> s { overridingTimeout = True }
411   liftIO $ forkIO $ do
412     threadDelay (duration * 1000000)
413     writeChan chan (TimeoutMessage timeoutId)
414   res <- f
415   bumpTimeoutId timeoutId
416   modify $ \s -> s { overridingTimeout = False }
417   return res
418
419 data LogMsgType = LogServer | LogClient
420   deriving Eq
421
422 -- | Logs the message if the config specified it
423 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
424        => LogMsgType -> a -> m ()
425 logMsg t msg = do
426   shouldLog <- asks $ logMessages . config
427   shouldColor <- asks $ logColor . config
428   liftIO $ when shouldLog $ do
429     when shouldColor $ setSGR [SetColor Foreground Dull color]
430     putStrLn $ arrow ++ showPretty msg
431     when shouldColor $ setSGR [Reset]
432
433   where arrow
434           | t == LogServer  = "<-- "
435           | otherwise       = "--> "
436         color
437           | t == LogServer  = Magenta
438           | otherwise       = Cyan
439
440         showPretty = B.unpack . encodePretty