9076a8e5e02e9ed819479a7d10c2f3a2d4ffc21a
[lsp-test.git] / src / Language / Haskell / LSP / Test / Session.hs
1 {-# LANGUAGE CPP               #-}
2 {-# LANGUAGE OverloadedStrings #-}
3 {-# LANGUAGE GeneralizedNewtypeDeriving #-}
4 {-# LANGUAGE FlexibleInstances #-}
5 {-# LANGUAGE MultiParamTypeClasses #-}
6 {-# LANGUAGE FlexibleContexts #-}
7 {-# LANGUAGE RankNTypes #-}
8
9 module Language.Haskell.LSP.Test.Session
10   ( Session(..)
11   , SessionConfig(..)
12   , defaultConfig
13   , SessionMessage(..)
14   , SessionContext(..)
15   , SessionState(..)
16   , runSessionWithHandles
17   , get
18   , put
19   , modify
20   , modifyM
21   , ask
22   , asks
23   , sendMessage
24   , updateState
25   , withTimeout
26   , getCurTimeoutId
27   , bumpTimeoutId
28   , logMsg
29   , LogMsgType(..)
30   )
31
32 where
33
34 import Control.Applicative
35 import Control.Concurrent hiding (yield)
36 import Control.Exception
37 import Control.Lens hiding (List)
38 import Control.Monad
39 import Control.Monad.IO.Class
40 import Control.Monad.Except
41 #if __GLASGOW_HASKELL__ == 806
42 import Control.Monad.Fail
43 #endif
44 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
45 import qualified Control.Monad.Trans.Reader as Reader (ask)
46 import Control.Monad.Trans.State (StateT, runStateT)
47 import qualified Control.Monad.Trans.State as State
48 import qualified Data.ByteString.Lazy.Char8 as B
49 import Data.Aeson
50 import Data.Aeson.Encode.Pretty
51 import Data.Conduit as Conduit
52 import Data.Conduit.Parser as Parser
53 import Data.Default
54 import Data.Foldable
55 import Data.List
56 import qualified Data.Map as Map
57 import qualified Data.Text as T
58 import qualified Data.Text.IO as T
59 import qualified Data.HashMap.Strict as HashMap
60 import Data.Maybe
61 import Data.Function
62 import Language.Haskell.LSP.Messages
63 import Language.Haskell.LSP.Types.Capabilities
64 import Language.Haskell.LSP.Types
65 import Language.Haskell.LSP.Types.Lens
66 import qualified Language.Haskell.LSP.Types.Lens as LSP
67 import Language.Haskell.LSP.VFS
68 import Language.Haskell.LSP.Test.Compat
69 import Language.Haskell.LSP.Test.Decoding
70 import Language.Haskell.LSP.Test.Exceptions
71 import System.Console.ANSI
72 import System.Directory
73 import System.IO
74 import System.Process (ProcessHandle())
75 import System.Timeout
76
77 -- | A session representing one instance of launching and connecting to a server.
78 --
79 -- You can send and receive messages to the server within 'Session' via
80 -- 'Language.Haskell.LSP.Test.message',
81 -- 'Language.Haskell.LSP.Test.sendRequest' and
82 -- 'Language.Haskell.LSP.Test.sendNotification'.
83
84 newtype Session a = Session (ConduitParser FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) a)
85   deriving (Functor, Applicative, Monad, MonadIO, Alternative)
86
87 #if __GLASGOW_HASKELL__ >= 806
88 instance MonadFail Session where
89   fail s = do
90     lastMsg <- fromJust . lastReceivedMessage <$> get
91     liftIO $ throw (UnexpectedMessage s lastMsg)
92 #endif
93
94 -- | Stuff you can configure for a 'Session'.
95 data SessionConfig = SessionConfig
96   { messageTimeout :: Int  -- ^ Maximum time to wait for a message in seconds, defaults to 60.
97   , logStdErr      :: Bool
98   -- ^ Redirect the server's stderr to this stdout, defaults to False.
99   -- Can be overriden with @LSP_TEST_LOG_STDERR@.
100   , logMessages    :: Bool
101   -- ^ Trace the messages sent and received to stdout, defaults to False.
102   -- Can be overriden with the environment variable @LSP_TEST_LOG_MESSAGES@.
103   , logColor       :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
104   , lspConfig      :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
105   , ignoreLogNotifications :: Bool
106   -- ^ Whether or not to ignore 'Language.Haskell.LSP.Types.ShowMessageNotification' and
107   -- 'Language.Haskell.LSP.Types.LogMessageNotification', defaults to False.
108   --
109   -- @since 0.9.0.0
110   }
111
112 -- | The configuration used in 'Language.Haskell.LSP.Test.runSession'.
113 defaultConfig :: SessionConfig
114 defaultConfig = SessionConfig 60 False False True Nothing False
115
116 instance Default SessionConfig where
117   def = defaultConfig
118
119 data SessionMessage = ServerMessage FromServerMessage
120                     | TimeoutMessage Int
121   deriving Show
122
123 data SessionContext = SessionContext
124   {
125     serverIn :: Handle
126   , rootDir :: FilePath
127   , messageChan :: Chan SessionMessage -- ^ Where all messages come through
128   -- Keep curTimeoutId in SessionContext, as its tied to messageChan
129   , curTimeoutId :: MVar Int -- ^ The current timeout we are waiting on
130   , requestMap :: MVar RequestMap
131   , initRsp :: MVar InitializeResponse
132   , config :: SessionConfig
133   , sessionCapabilities :: ClientCapabilities
134   }
135
136 class Monad m => HasReader r m where
137   ask :: m r
138   asks :: (r -> b) -> m b
139   asks f = f <$> ask
140
141 instance HasReader SessionContext Session where
142   ask  = Session (lift $ lift Reader.ask)
143
144 instance Monad m => HasReader r (ConduitM a b (StateT s (ReaderT r m))) where
145   ask = lift $ lift Reader.ask
146
147 getCurTimeoutId :: (HasReader SessionContext m, MonadIO m) => m Int
148 getCurTimeoutId = asks curTimeoutId >>= liftIO . readMVar
149
150 -- Pass this the timeoutid you *were* waiting on
151 bumpTimeoutId :: (HasReader SessionContext m, MonadIO m) => Int -> m ()
152 bumpTimeoutId prev = do
153   v <- asks curTimeoutId
154   -- when updating the curtimeoutid, account for the fact that something else
155   -- might have bumped the timeoutid in the meantime
156   liftIO $ modifyMVar_ v (\x -> pure (max x (prev + 1)))
157
158 data SessionState = SessionState
159   {
160     curReqId :: LspId
161   , vfs :: VFS
162   , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
163   , overridingTimeout :: Bool
164   -- ^ The last received message from the server.
165   -- Used for providing exception information
166   , lastReceivedMessage :: Maybe FromServerMessage
167   , curDynCaps :: Map.Map T.Text Registration
168   -- ^ The capabilities that the server has dynamically registered with us so
169   -- far
170   }
171
172 class Monad m => HasState s m where
173   get :: m s
174
175   put :: s -> m ()
176
177   modify :: (s -> s) -> m ()
178   modify f = get >>= put . f
179
180   modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
181   modifyM f = get >>= f >>= put
182
183 instance HasState SessionState Session where
184   get = Session (lift State.get)
185   put = Session . lift . State.put
186
187 instance Monad m => HasState s (StateT s m) where
188   get = State.get
189   put = State.put
190
191 instance (Monad m, (HasState s m)) => HasState s (ConduitM a b m)
192  where
193   get = lift get
194   put = lift . put
195
196 instance (Monad m, (HasState s m)) => HasState s (ConduitParser a m)
197  where
198   get = lift get
199   put = lift . put
200
201 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
202 runSession context state (Session session) = runReaderT (runStateT conduit state) context
203   where
204     conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
205
206     handler (Unexpected "ConduitParser.empty") = do
207       lastMsg <- fromJust . lastReceivedMessage <$> get
208       name <- getParserName
209       liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
210
211     handler e = throw e
212
213     chanSource = do
214       msg <- liftIO $ readChan (messageChan context)
215       unless (ignoreLogNotifications (config context) && isLogNotification msg) $
216         yield msg
217       chanSource
218
219     isLogNotification (ServerMessage (NotShowMessage _)) = True
220     isLogNotification (ServerMessage (NotLogMessage _)) = True
221     isLogNotification _ = False
222
223     watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
224     watchdog = Conduit.awaitForever $ \msg -> do
225       curId <- getCurTimeoutId
226       case msg of
227         ServerMessage sMsg -> yield sMsg
228         TimeoutMessage tId -> when (curId == tId) $ lastReceivedMessage <$> get >>= throw . Timeout
229
230 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
231 -- It also does not automatically send initialize and exit messages.
232 runSessionWithHandles :: Handle -- ^ Server in
233                       -> Handle -- ^ Server out
234                       -> ProcessHandle -- ^ Server process
235                       -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
236                       -> SessionConfig
237                       -> ClientCapabilities
238                       -> FilePath -- ^ Root directory
239                       -> Session () -- ^ To exit the Server properly
240                       -> Session a
241                       -> IO a
242 runSessionWithHandles serverIn serverOut serverProc serverHandler config caps rootDir exitServer session = do
243   absRootDir <- canonicalizePath rootDir
244
245   hSetBuffering serverIn  NoBuffering
246   hSetBuffering serverOut NoBuffering
247   -- This is required to make sure that we don’t get any
248   -- newline conversion or weird encoding issues.
249   hSetBinaryMode serverIn True
250   hSetBinaryMode serverOut True
251
252   reqMap <- newMVar newRequestMap
253   messageChan <- newChan
254   timeoutIdVar <- newMVar 0
255   initRsp <- newEmptyMVar
256
257   mainThreadId <- myThreadId
258
259   let context = SessionContext serverIn absRootDir messageChan timeoutIdVar reqMap initRsp config caps
260       initState vfs = SessionState (IdInt 0) vfs mempty False Nothing mempty
261       runSession' ses = initVFS $ \vfs -> runSession context (initState vfs) ses
262
263       errorHandler = throwTo mainThreadId :: SessionException -> IO ()
264       serverListenerLauncher =
265         forkIO $ catch (serverHandler serverOut context) errorHandler
266       server = (Just serverIn, Just serverOut, Nothing, serverProc)
267       serverAndListenerFinalizer tid = do
268         finally (timeout (messageTimeout config * 1^6)
269                          (runSession' exitServer))
270                 -- Make sure to kill the listener first, before closing
271                 -- handles etc via cleanupProcess
272                 (killThread tid >> cleanupProcess server)
273
274   (result, _) <- bracket serverListenerLauncher
275                          serverAndListenerFinalizer
276                          (const $ runSession' session)
277   return result
278
279 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
280 updateStateC = awaitForever $ \msg -> do
281   updateState msg
282   yield msg
283
284 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m)
285             => FromServerMessage -> m ()
286
287 -- Keep track of dynamic capability registration
288 updateState (ReqRegisterCapability req) = do
289   let List newRegs = (\r -> (r ^. LSP.id, r)) <$> req ^. params . registrations
290   modify $ \s ->
291     s { curDynCaps = Map.union (Map.fromList newRegs) (curDynCaps s) }
292
293 updateState (ReqUnregisterCapability req) = do
294   let List unRegs = (^. LSP.id) <$> req ^. params . unregistrations
295   modify $ \s ->
296     let newCurDynCaps = foldr' Map.delete (curDynCaps s) unRegs
297     in s { curDynCaps = newCurDynCaps }
298
299 updateState (NotPublishDiagnostics n) = do
300   let List diags = n ^. params . diagnostics
301       doc = n ^. params . uri
302   modify $ \s ->
303     let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s)
304       in s { curDiagnostics = newDiags }
305
306 updateState (ReqApplyWorkspaceEdit r) = do
307
308   allChangeParams <- case r ^. params . edit . documentChanges of
309     Just (List cs) -> do
310       mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
311       return $ map getParams cs
312     Nothing -> case r ^. params . edit . changes of
313       Just cs -> do
314         mapM_ checkIfNeedsOpened (HashMap.keys cs)
315         return $ concatMap (uncurry getChangeParams) (HashMap.toList cs)
316       Nothing -> error "No changes!"
317
318   modifyM $ \s -> do
319     newVFS <- liftIO $ changeFromServerVFS (vfs s) r
320     return $ s { vfs = newVFS }
321
322   let groupedParams = groupBy (\a b -> a ^. textDocument == b ^. textDocument) allChangeParams
323       mergedParams = map mergeParams groupedParams
324
325   -- TODO: Don't do this when replaying a session
326   forM_ mergedParams (sendMessage . NotificationMessage "2.0" TextDocumentDidChange)
327
328   -- Update VFS to new document versions
329   let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
330       latestVersions = map ((^. textDocument) . last) sortedVersions
331       bumpedVersions = map (version . _Just +~ 1) latestVersions
332
333   forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
334     modify $ \s ->
335       let oldVFS = vfs s
336           update (VirtualFile oldV file_ver t) = VirtualFile (fromMaybe oldV v) (file_ver + 1) t
337           newVFS = updateVFS (Map.adjust update (toNormalizedUri uri)) oldVFS
338       in s { vfs = newVFS }
339
340   where checkIfNeedsOpened uri = do
341           oldVFS <- vfs <$> get
342           ctx <- ask
343
344           -- if its not open, open it
345           unless (toNormalizedUri uri `Map.member` vfsMap oldVFS) $ do
346             let fp = fromJust $ uriToFilePath uri
347             contents <- liftIO $ T.readFile fp
348             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
349                 msg = NotificationMessage "2.0" TextDocumentDidOpen (DidOpenTextDocumentParams item)
350             liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
351
352             modifyM $ \s -> do
353               let (newVFS,_) = openVFS (vfs s) msg
354               return $ s { vfs = newVFS }
355
356         getParams (TextDocumentEdit docId (List edits)) =
357           let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
358             in DidChangeTextDocumentParams docId (List changeEvents)
359
360         textDocumentVersions uri = map (VersionedTextDocumentIdentifier uri . Just) [0..]
361
362         textDocumentEdits uri edits = map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip (textDocumentVersions uri) edits
363
364         getChangeParams uri (List edits) = map getParams (textDocumentEdits uri (reverse edits))
365
366         mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
367         mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
368                               in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
369 updateState _ = return ()
370
371 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
372 sendMessage msg = do
373   h <- serverIn <$> ask
374   logMsg LogClient msg
375   liftIO $ B.hPut h (addHeader $ encode msg)
376
377 -- | Execute a block f that will throw a 'Language.Haskell.LSP.Test.Exception.Timeout' exception
378 -- after duration seconds. This will override the global timeout
379 -- for waiting for messages to arrive defined in 'SessionConfig'.
380 withTimeout :: Int -> Session a -> Session a
381 withTimeout duration f = do
382   chan <- asks messageChan
383   timeoutId <- getCurTimeoutId
384   modify $ \s -> s { overridingTimeout = True }
385   liftIO $ forkIO $ do
386     threadDelay (duration * 1000000)
387     writeChan chan (TimeoutMessage timeoutId)
388   res <- f
389   bumpTimeoutId timeoutId
390   modify $ \s -> s { overridingTimeout = False }
391   return res
392
393 data LogMsgType = LogServer | LogClient
394   deriving Eq
395
396 -- | Logs the message if the config specified it
397 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
398        => LogMsgType -> a -> m ()
399 logMsg t msg = do
400   shouldLog <- asks $ logMessages . config
401   shouldColor <- asks $ logColor . config
402   liftIO $ when shouldLog $ do
403     when shouldColor $ setSGR [SetColor Foreground Dull color]
404     putStrLn $ arrow ++ showPretty msg
405     when shouldColor $ setSGR [Reset]
406
407   where arrow
408           | t == LogServer  = "<-- "
409           | otherwise       = "--> "
410         color
411           | t == LogServer  = Magenta
412           | otherwise       = Cyan
413
414         showPretty = B.unpack . encodePretty