update and fill in `message`
[lsp-test.git] / src / Language / Haskell / LSP / Test / Session.hs
1 {-# LANGUAGE CPP               #-}
2 {-# LANGUAGE GADTs             #-}
3 {-# LANGUAGE OverloadedStrings #-}
4 {-# LANGUAGE GeneralizedNewtypeDeriving #-}
5 {-# LANGUAGE FlexibleInstances #-}
6 {-# LANGUAGE MultiParamTypeClasses #-}
7 {-# LANGUAGE FlexibleContexts #-}
8 {-# LANGUAGE RankNTypes #-}
9
10 module Language.Haskell.LSP.Test.Session
11   ( Session(..)
12   , SessionConfig(..)
13   , defaultConfig
14   , SessionMessage(..)
15   , SessionContext(..)
16   , SessionState(..)
17   , runSessionWithHandles
18   , get
19   , put
20   , modify
21   , modifyM
22   , ask
23   , asks
24   , sendMessage
25   , updateState
26   , withTimeout
27   , getCurTimeoutId
28   , bumpTimeoutId
29   , logMsg
30   , LogMsgType(..)
31   )
32
33 where
34
35 import Control.Applicative
36 import Control.Concurrent hiding (yield)
37 import Control.Exception
38 import Control.Lens hiding (List)
39 import Control.Monad
40 import Control.Monad.IO.Class
41 import Control.Monad.Except
42 #if __GLASGOW_HASKELL__ == 806
43 import Control.Monad.Fail
44 #endif
45 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
46 import qualified Control.Monad.Trans.Reader as Reader (ask)
47 import Control.Monad.Trans.State (StateT, runStateT)
48 import qualified Control.Monad.Trans.State as State
49 import qualified Data.ByteString.Lazy.Char8 as B
50 import Data.Aeson
51 import Data.Aeson.Encode.Pretty
52 import Data.Conduit as Conduit
53 import Data.Conduit.Parser as Parser
54 import Data.Default
55 import Data.Foldable
56 import Data.List
57 import qualified Data.Map as Map
58 import qualified Data.Text as T
59 import qualified Data.Text.IO as T
60 import qualified Data.HashMap.Strict as HashMap
61 import Data.Maybe
62 import Data.Function
63 import Language.Haskell.LSP.Types.Capabilities
64 import Language.Haskell.LSP.Types
65 import Language.Haskell.LSP.Types.Lens
66 import qualified Language.Haskell.LSP.Types.Lens as LSP
67 import Language.Haskell.LSP.VFS
68 import Language.Haskell.LSP.Test.Compat
69 import Language.Haskell.LSP.Test.Decoding
70 import Language.Haskell.LSP.Test.Exceptions
71 import System.Console.ANSI
72 import System.Directory
73 import System.IO
74 import System.Process (ProcessHandle())
75 #ifndef mingw32_HOST_OS
76 import System.Process (waitForProcess)
77 #endif
78 import System.Timeout
79
80 -- | A session representing one instance of launching and connecting to a server.
81 --
82 -- You can send and receive messages to the server within 'Session' via
83 -- 'Language.Haskell.LSP.Test.message',
84 -- 'Language.Haskell.LSP.Test.sendRequest' and
85 -- 'Language.Haskell.LSP.Test.sendNotification'.
86
87 newtype Session a = Session (ConduitParser FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) a)
88   deriving (Functor, Applicative, Monad, MonadIO, Alternative)
89
90 #if __GLASGOW_HASKELL__ >= 806
91 instance MonadFail Session where
92   fail s = do
93     lastMsg <- fromJust . lastReceivedMessage <$> get
94     liftIO $ throw (UnexpectedMessage s lastMsg)
95 #endif
96
97 -- | Stuff you can configure for a 'Session'.
98 data SessionConfig = SessionConfig
99   { messageTimeout :: Int  -- ^ Maximum time to wait for a message in seconds, defaults to 60.
100   , logStdErr      :: Bool
101   -- ^ Redirect the server's stderr to this stdout, defaults to False.
102   -- Can be overriden with @LSP_TEST_LOG_STDERR@.
103   , logMessages    :: Bool
104   -- ^ Trace the messages sent and received to stdout, defaults to False.
105   -- Can be overriden with the environment variable @LSP_TEST_LOG_MESSAGES@.
106   , logColor       :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
107   , lspConfig      :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
108   , ignoreLogNotifications :: Bool
109   -- ^ Whether or not to ignore 'Language.Haskell.LSP.Types.ShowMessageNotification' and
110   -- 'Language.Haskell.LSP.Types.LogMessageNotification', defaults to False.
111   --
112   -- @since 0.9.0.0
113   }
114
115 -- | The configuration used in 'Language.Haskell.LSP.Test.runSession'.
116 defaultConfig :: SessionConfig
117 defaultConfig = SessionConfig 60 False False True Nothing False
118
119 instance Default SessionConfig where
120   def = defaultConfig
121
122 data SessionMessage = ServerMessage FromServerMessage
123                     | TimeoutMessage Int
124   deriving Show
125
126 data SessionContext = SessionContext
127   {
128     serverIn :: Handle
129   , rootDir :: FilePath
130   , messageChan :: Chan SessionMessage -- ^ Where all messages come through
131   -- Keep curTimeoutId in SessionContext, as its tied to messageChan
132   , curTimeoutId :: MVar Int -- ^ The current timeout we are waiting on
133   , requestMap :: MVar RequestMap
134   , initRsp :: MVar InitializeResponse
135   , config :: SessionConfig
136   , sessionCapabilities :: ClientCapabilities
137   }
138
139 class Monad m => HasReader r m where
140   ask :: m r
141   asks :: (r -> b) -> m b
142   asks f = f <$> ask
143
144 instance HasReader SessionContext Session where
145   ask  = Session (lift $ lift Reader.ask)
146
147 instance Monad m => HasReader r (ConduitM a b (StateT s (ReaderT r m))) where
148   ask = lift $ lift Reader.ask
149
150 getCurTimeoutId :: (HasReader SessionContext m, MonadIO m) => m Int
151 getCurTimeoutId = asks curTimeoutId >>= liftIO . readMVar
152
153 -- Pass this the timeoutid you *were* waiting on
154 bumpTimeoutId :: (HasReader SessionContext m, MonadIO m) => Int -> m ()
155 bumpTimeoutId prev = do
156   v <- asks curTimeoutId
157   -- when updating the curtimeoutid, account for the fact that something else
158   -- might have bumped the timeoutid in the meantime
159   liftIO $ modifyMVar_ v (\x -> pure (max x (prev + 1)))
160
161 data SessionState = SessionState
162   {
163     curReqId :: Int
164   , vfs :: VFS
165   , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
166   , overridingTimeout :: Bool
167   -- ^ The last received message from the server.
168   -- Used for providing exception information
169   , lastReceivedMessage :: Maybe FromServerMessage
170   , curDynCaps :: Map.Map T.Text SomeRegistration
171   -- ^ The capabilities that the server has dynamically registered with us so
172   -- far
173   }
174
175 class Monad m => HasState s m where
176   get :: m s
177
178   put :: s -> m ()
179
180   modify :: (s -> s) -> m ()
181   modify f = get >>= put . f
182
183   modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
184   modifyM f = get >>= f >>= put
185
186 instance HasState SessionState Session where
187   get = Session (lift State.get)
188   put = Session . lift . State.put
189
190 instance Monad m => HasState s (StateT s m) where
191   get = State.get
192   put = State.put
193
194 instance (Monad m, (HasState s m)) => HasState s (ConduitM a b m)
195  where
196   get = lift get
197   put = lift . put
198
199 instance (Monad m, (HasState s m)) => HasState s (ConduitParser a m)
200  where
201   get = lift get
202   put = lift . put
203
204 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
205 runSession context state (Session session) = runReaderT (runStateT conduit state) context
206   where
207     conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
208
209     handler (Unexpected "ConduitParser.empty") = do
210       lastMsg <- fromJust . lastReceivedMessage <$> get
211       name <- getParserName
212       liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
213
214     handler e = throw e
215
216     chanSource = do
217       msg <- liftIO $ readChan (messageChan context)
218       unless (ignoreLogNotifications (config context) && isLogNotification msg) $
219         yield msg
220       chanSource
221
222     isLogNotification (ServerMessage (FromServerMess SWindowShowMessage _)) = True
223     isLogNotification (ServerMessage (FromServerMess SWindowLogMessage _)) = True
224     isLogNotification _ = False
225
226     watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
227     watchdog = Conduit.awaitForever $ \msg -> do
228       curId <- getCurTimeoutId
229       case msg of
230         ServerMessage sMsg -> yield sMsg
231         TimeoutMessage tId -> when (curId == tId) $ lastReceivedMessage <$> get >>= throw . Timeout
232
233 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
234 -- It also does not automatically send initialize and exit messages.
235 runSessionWithHandles :: Handle -- ^ Server in
236                       -> Handle -- ^ Server out
237                       -> ProcessHandle -- ^ Server process
238                       -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
239                       -> SessionConfig
240                       -> ClientCapabilities
241                       -> FilePath -- ^ Root directory
242                       -> Session () -- ^ To exit the Server properly
243                       -> Session a
244                       -> IO a
245 runSessionWithHandles serverIn serverOut serverProc serverHandler config caps rootDir exitServer session = do
246   absRootDir <- canonicalizePath rootDir
247
248   hSetBuffering serverIn  NoBuffering
249   hSetBuffering serverOut NoBuffering
250   -- This is required to make sure that we don’t get any
251   -- newline conversion or weird encoding issues.
252   hSetBinaryMode serverIn True
253   hSetBinaryMode serverOut True
254
255   reqMap <- newMVar newRequestMap
256   messageChan <- newChan
257   timeoutIdVar <- newMVar 0
258   initRsp <- newEmptyMVar
259
260   mainThreadId <- myThreadId
261
262   let context = SessionContext serverIn absRootDir messageChan timeoutIdVar reqMap initRsp config caps
263       initState vfs = SessionState 0 vfs mempty False Nothing mempty
264       runSession' ses = initVFS $ \vfs -> runSession context (initState vfs) ses
265
266       errorHandler = throwTo mainThreadId :: SessionException -> IO ()
267       serverListenerLauncher =
268         forkIO $ catch (serverHandler serverOut context) errorHandler
269       server = (Just serverIn, Just serverOut, Nothing, serverProc)
270       msgTimeoutMs = messageTimeout config * 10^6
271       serverAndListenerFinalizer tid = do
272         finally (timeout msgTimeoutMs (runSession' exitServer)) $ do
273           -- Make sure to kill the listener first, before closing
274           -- handles etc via cleanupProcess
275           killThread tid
276           -- Give the server some time to exit cleanly
277           -- It makes the server hangs in windows so we have to avoid it
278 #ifndef mingw32_HOST_OS
279           timeout msgTimeoutMs (waitForProcess serverProc)
280 #endif
281           cleanupProcess server
282
283   (result, _) <- bracket serverListenerLauncher
284                          serverAndListenerFinalizer
285                          (const $ initVFS $ \vfs -> runSession context (initState vfs) session)
286   return result
287
288 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
289 updateStateC = awaitForever $ \msg -> do
290   updateState msg
291   yield msg
292
293 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m)
294             => FromServerMessage -> m ()
295
296 -- Keep track of dynamic capability registration
297 updateState (FromServerMess SClientRegisterCapability req) = do
298   let List newRegs = (\sr@(SomeRegistration r) -> (r ^. LSP.id, sr)) <$> req ^. params . registrations
299   modify $ \s ->
300     s { curDynCaps = Map.union (Map.fromList newRegs) (curDynCaps s) }
301
302 updateState (FromServerMess SClientUnregisterCapability req) = do
303   let List unRegs = (^. LSP.id) <$> req ^. params . unregistrations
304   modify $ \s ->
305     let newCurDynCaps = foldr' Map.delete (curDynCaps s) unRegs
306     in s { curDynCaps = newCurDynCaps }
307
308 updateState (FromServerMess STextDocumentPublishDiagnostics n) = do
309   let List diags = n ^. params . diagnostics
310       doc = n ^. params . uri
311   modify $ \s ->
312     let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s)
313       in s { curDiagnostics = newDiags }
314
315 updateState (FromServerMess SWorkspaceApplyEdit r) = do
316
317   -- First, prefer the versioned documentChanges field
318   allChangeParams <- case r ^. params . edit . documentChanges of
319     Just (List cs) -> do
320       mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
321       return $ map getParams cs
322     -- Then fall back to the changes field
323     Nothing -> case r ^. params . edit . changes of
324       Just cs -> do
325         mapM_ checkIfNeedsOpened (HashMap.keys cs)
326         concat <$> mapM (uncurry getChangeParams) (HashMap.toList cs)
327       Nothing ->
328         error "WorkspaceEdit contains neither documentChanges nor changes!"
329
330   modifyM $ \s -> do
331     newVFS <- liftIO $ changeFromServerVFS (vfs s) r
332     return $ s { vfs = newVFS }
333
334   let groupedParams = groupBy (\a b -> a ^. textDocument == b ^. textDocument) allChangeParams
335       mergedParams = map mergeParams groupedParams
336
337   -- TODO: Don't do this when replaying a session
338   forM_ mergedParams (sendMessage . NotificationMessage "2.0" STextDocumentDidChange)
339
340   -- Update VFS to new document versions
341   let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
342       latestVersions = map ((^. textDocument) . last) sortedVersions
343       bumpedVersions = map (version . _Just +~ 1) latestVersions
344
345   forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
346     modify $ \s ->
347       let oldVFS = vfs s
348           update (VirtualFile oldV file_ver t) = VirtualFile (fromMaybe oldV v) (file_ver + 1) t
349           newVFS = updateVFS (Map.adjust update (toNormalizedUri uri)) oldVFS
350       in s { vfs = newVFS }
351
352   where checkIfNeedsOpened uri = do
353           oldVFS <- vfs <$> get
354           ctx <- ask
355
356           -- if its not open, open it
357           unless (toNormalizedUri uri `Map.member` vfsMap oldVFS) $ do
358             let fp = fromJust $ uriToFilePath uri
359             contents <- liftIO $ T.readFile fp
360             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
361                 msg = NotificationMessage "2.0" STextDocumentDidOpen (DidOpenTextDocumentParams item)
362             liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
363
364             modifyM $ \s -> do
365               let (newVFS,_) = openVFS (vfs s) msg
366               return $ s { vfs = newVFS }
367
368         getParams (TextDocumentEdit docId (List edits)) =
369           let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
370             in DidChangeTextDocumentParams docId (List changeEvents)
371
372         -- For a uri returns an infinite list of versions [n,n+1,n+2,...]
373         -- where n is the current version
374         textDocumentVersions uri = do
375           m <- vfsMap . vfs <$> get
376           let curVer = fromMaybe 0 $
377                 _lsp_version <$> m Map.!? (toNormalizedUri uri)
378           pure $ map (VersionedTextDocumentIdentifier uri . Just) [curVer + 1..]
379
380         textDocumentEdits uri edits = do
381           vers <- textDocumentVersions uri
382           pure $ map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip vers edits
383
384         getChangeParams uri (List edits) =
385           map <$> pure getParams <*> textDocumentEdits uri (reverse edits)
386
387         mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
388         mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
389                               in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
390 updateState _ = return ()
391
392 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
393 sendMessage msg = do
394   h <- serverIn <$> ask
395   logMsg LogClient msg
396   liftIO $ B.hPut h (addHeader $ encode msg)
397
398 -- | Execute a block f that will throw a 'Language.Haskell.LSP.Test.Exception.Timeout' exception
399 -- after duration seconds. This will override the global timeout
400 -- for waiting for messages to arrive defined in 'SessionConfig'.
401 withTimeout :: Int -> Session a -> Session a
402 withTimeout duration f = do
403   chan <- asks messageChan
404   timeoutId <- getCurTimeoutId
405   modify $ \s -> s { overridingTimeout = True }
406   liftIO $ forkIO $ do
407     threadDelay (duration * 1000000)
408     writeChan chan (TimeoutMessage timeoutId)
409   res <- f
410   bumpTimeoutId timeoutId
411   modify $ \s -> s { overridingTimeout = False }
412   return res
413
414 data LogMsgType = LogServer | LogClient
415   deriving Eq
416
417 -- | Logs the message if the config specified it
418 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
419        => LogMsgType -> a -> m ()
420 logMsg t msg = do
421   shouldLog <- asks $ logMessages . config
422   shouldColor <- asks $ logColor . config
423   liftIO $ when shouldLog $ do
424     when shouldColor $ setSGR [SetColor Foreground Dull color]
425     putStrLn $ arrow ++ showPretty msg
426     when shouldColor $ setSGR [Reset]
427
428   where arrow
429           | t == LogServer  = "<-- "
430           | otherwise       = "--> "
431         color
432           | t == LogServer  = Magenta
433           | otherwise       = Cyan
434
435         showPretty = B.unpack . encodePretty