Fix curtimeoutid being reset in the server exit handler
[lsp-test.git] / src / Language / Haskell / LSP / Test / Session.hs
1 {-# LANGUAGE CPP               #-}
2 {-# LANGUAGE OverloadedStrings #-}
3 {-# LANGUAGE GeneralizedNewtypeDeriving #-}
4 {-# LANGUAGE FlexibleInstances #-}
5 {-# LANGUAGE MultiParamTypeClasses #-}
6 {-# LANGUAGE FlexibleContexts #-}
7 {-# LANGUAGE RankNTypes #-}
8
9 module Language.Haskell.LSP.Test.Session
10   ( Session(..)
11   , SessionConfig(..)
12   , defaultConfig
13   , SessionMessage(..)
14   , SessionContext(..)
15   , SessionState(..)
16   , runSessionWithHandles
17   , get
18   , put
19   , modify
20   , modifyM
21   , ask
22   , asks
23   , sendMessage
24   , updateState
25   , withTimeout
26   , getCurTimeoutId
27   , bumpTimeoutId
28   , logMsg
29   , LogMsgType(..)
30   )
31
32 where
33
34 import Control.Applicative
35 import Control.Concurrent hiding (yield)
36 import Control.Exception
37 import Control.Lens hiding (List)
38 import Control.Monad
39 import Control.Monad.IO.Class
40 import Control.Monad.Except
41 #if __GLASGOW_HASKELL__ == 806
42 import Control.Monad.Fail
43 #endif
44 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
45 import qualified Control.Monad.Trans.Reader as Reader (ask)
46 import Control.Monad.Trans.State (StateT, runStateT)
47 import qualified Control.Monad.Trans.State as State
48 import qualified Data.ByteString.Lazy.Char8 as B
49 import Data.Aeson
50 import Data.Aeson.Encode.Pretty
51 import Data.Conduit as Conduit
52 import Data.Conduit.Parser as Parser
53 import Data.Default
54 import Data.Foldable
55 import Data.List
56 import qualified Data.Map as Map
57 import qualified Data.Text as T
58 import qualified Data.Text.IO as T
59 import qualified Data.HashMap.Strict as HashMap
60 import Data.Maybe
61 import Data.Function
62 import Language.Haskell.LSP.Messages
63 import Language.Haskell.LSP.Types.Capabilities
64 import Language.Haskell.LSP.Types
65 import Language.Haskell.LSP.Types.Lens
66 import Language.Haskell.LSP.VFS
67 import Language.Haskell.LSP.Test.Compat
68 import Language.Haskell.LSP.Test.Decoding
69 import Language.Haskell.LSP.Test.Exceptions
70 import System.Console.ANSI
71 import System.Directory
72 import System.IO
73 import System.Process (ProcessHandle())
74 import System.Timeout
75
76 -- | A session representing one instance of launching and connecting to a server.
77 --
78 -- You can send and receive messages to the server within 'Session' via
79 -- 'Language.Haskell.LSP.Test.message',
80 -- 'Language.Haskell.LSP.Test.sendRequest' and
81 -- 'Language.Haskell.LSP.Test.sendNotification'.
82
83 newtype Session a = Session (ConduitParser FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) a)
84   deriving (Functor, Applicative, Monad, MonadIO, Alternative)
85
86 #if __GLASGOW_HASKELL__ >= 806
87 instance MonadFail Session where
88   fail s = do
89     lastMsg <- fromJust . lastReceivedMessage <$> get
90     liftIO $ throw (UnexpectedMessage s lastMsg)
91 #endif
92
93 -- | Stuff you can configure for a 'Session'.
94 data SessionConfig = SessionConfig
95   { messageTimeout :: Int  -- ^ Maximum time to wait for a message in seconds, defaults to 60.
96   , logStdErr      :: Bool
97   -- ^ Redirect the server's stderr to this stdout, defaults to False.
98   -- Can be overriden with @LSP_TEST_LOG_STDERR@.
99   , logMessages    :: Bool
100   -- ^ Trace the messages sent and received to stdout, defaults to False.
101   -- Can be overriden with the environment variable @LSP_TEST_LOG_MESSAGES@.
102   , logColor       :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
103   , lspConfig      :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
104   , ignoreLogNotifications :: Bool
105   -- ^ Whether or not to ignore 'Language.Haskell.LSP.Types.ShowMessageNotification' and
106   -- 'Language.Haskell.LSP.Types.LogMessageNotification', defaults to False.
107   --
108   -- @since 0.9.0.0
109   }
110
111 -- | The configuration used in 'Language.Haskell.LSP.Test.runSession'.
112 defaultConfig :: SessionConfig
113 defaultConfig = SessionConfig 60 False False True Nothing False
114
115 instance Default SessionConfig where
116   def = defaultConfig
117
118 data SessionMessage = ServerMessage FromServerMessage
119                     | TimeoutMessage Int
120   deriving Show
121
122 data SessionContext = SessionContext
123   {
124     serverIn :: Handle
125   , rootDir :: FilePath
126   , messageChan :: Chan SessionMessage -- ^ Where all messages come through
127   -- Keep curTimeoutId in SessionContext, as its tied to messageChan
128   , curTimeoutId :: MVar Int -- ^ The current timeout we are waiting on
129   , requestMap :: MVar RequestMap
130   , initRsp :: MVar InitializeResponse
131   , config :: SessionConfig
132   , sessionCapabilities :: ClientCapabilities
133   }
134
135 class Monad m => HasReader r m where
136   ask :: m r
137   asks :: (r -> b) -> m b
138   asks f = f <$> ask
139
140 instance HasReader SessionContext Session where
141   ask  = Session (lift $ lift Reader.ask)
142
143 instance Monad m => HasReader r (ConduitM a b (StateT s (ReaderT r m))) where
144   ask = lift $ lift Reader.ask
145
146 getCurTimeoutId :: (HasReader SessionContext m, MonadIO m) => m Int
147 getCurTimeoutId = asks curTimeoutId >>= liftIO . readMVar
148
149 -- Pass this the timeoutid you *were* waiting on
150 bumpTimeoutId :: (HasReader SessionContext m, MonadIO m) => Int -> m ()
151 bumpTimeoutId prev = do
152   v <- asks curTimeoutId
153   -- when updating the curtimeoutid, account for the fact that something else
154   -- might have bumped the timeoutid in the meantime
155   liftIO $ modifyMVar_ v (\x -> pure (max x (prev + 1)))
156
157 data SessionState = SessionState
158   {
159     curReqId :: LspId
160   , vfs :: VFS
161   , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
162   , overridingTimeout :: Bool
163   -- ^ The last received message from the server.
164   -- Used for providing exception information
165   , lastReceivedMessage :: Maybe FromServerMessage
166   }
167
168 class Monad m => HasState s m where
169   get :: m s
170
171   put :: s -> m ()
172
173   modify :: (s -> s) -> m ()
174   modify f = get >>= put . f
175
176   modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
177   modifyM f = get >>= f >>= put
178
179 instance HasState SessionState Session where
180   get = Session (lift State.get)
181   put = Session . lift . State.put
182
183 instance Monad m => HasState s (StateT s m) where
184   get = State.get
185   put = State.put
186
187 instance (Monad m, (HasState s m)) => HasState s (ConduitM a b m)
188  where
189   get = lift get
190   put = lift . put
191
192 instance (Monad m, (HasState s m)) => HasState s (ConduitParser a m)
193  where
194   get = lift get
195   put = lift . put
196
197 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
198 runSession context state (Session session) = runReaderT (runStateT conduit state) context
199   where
200     conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
201
202     handler (Unexpected "ConduitParser.empty") = do
203       lastMsg <- fromJust . lastReceivedMessage <$> get
204       name <- getParserName
205       liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
206
207     handler e = throw e
208
209     chanSource = do
210       msg <- liftIO $ readChan (messageChan context)
211       unless (ignoreLogNotifications (config context) && isLogNotification msg) $
212         yield msg
213       chanSource
214
215     isLogNotification (ServerMessage (NotShowMessage _)) = True
216     isLogNotification (ServerMessage (NotLogMessage _)) = True
217     isLogNotification _ = False
218
219     watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
220     watchdog = Conduit.awaitForever $ \msg -> do
221       curId <- getCurTimeoutId
222       case msg of
223         ServerMessage sMsg -> yield sMsg
224         TimeoutMessage tId -> when (curId == tId) $ lastReceivedMessage <$> get >>= throw . Timeout
225
226 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
227 -- It also does not automatically send initialize and exit messages.
228 runSessionWithHandles :: Handle -- ^ Server in
229                       -> Handle -- ^ Server out
230                       -> ProcessHandle -- ^ Server process
231                       -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
232                       -> SessionConfig
233                       -> ClientCapabilities
234                       -> FilePath -- ^ Root directory
235                       -> Session () -- ^ To exit the Server properly
236                       -> Session a
237                       -> IO a
238 runSessionWithHandles serverIn serverOut serverProc serverHandler config caps rootDir exitServer session = do
239   absRootDir <- canonicalizePath rootDir
240
241   hSetBuffering serverIn  NoBuffering
242   hSetBuffering serverOut NoBuffering
243   -- This is required to make sure that we don’t get any
244   -- newline conversion or weird encoding issues.
245   hSetBinaryMode serverIn True
246   hSetBinaryMode serverOut True
247
248   reqMap <- newMVar newRequestMap
249   messageChan <- newChan
250   timeoutIdVar <- newMVar 0
251   initRsp <- newEmptyMVar
252
253   mainThreadId <- myThreadId
254
255   let context = SessionContext serverIn absRootDir messageChan timeoutIdVar reqMap initRsp config caps
256       initState vfs = SessionState (IdInt 0) vfs mempty False Nothing
257       runSession' ses = initVFS $ \vfs -> runSession context (initState vfs) ses
258
259       errorHandler = throwTo mainThreadId :: SessionException -> IO ()
260       serverListenerLauncher =
261         forkIO $ catch (serverHandler serverOut context) errorHandler
262       server = (Just serverIn, Just serverOut, Nothing, serverProc)
263       serverAndListenerFinalizer tid = do
264         finally (timeout (messageTimeout config * 1^6)
265                          (runSession' exitServer))
266                 (cleanupProcess server >> killThread tid)
267
268   (result, _) <- bracket serverListenerLauncher
269                          serverAndListenerFinalizer
270                          (const $ runSession' session)
271   return result
272
273 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
274 updateStateC = awaitForever $ \msg -> do
275   updateState msg
276   yield msg
277
278 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m)
279             => FromServerMessage -> m ()
280 updateState (NotPublishDiagnostics n) = do
281   let List diags = n ^. params . diagnostics
282       doc = n ^. params . uri
283   modify (\s ->
284     let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s)
285       in s { curDiagnostics = newDiags })
286
287 updateState (ReqApplyWorkspaceEdit r) = do
288
289   allChangeParams <- case r ^. params . edit . documentChanges of
290     Just (List cs) -> do
291       mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
292       return $ map getParams cs
293     Nothing -> case r ^. params . edit . changes of
294       Just cs -> do
295         mapM_ checkIfNeedsOpened (HashMap.keys cs)
296         return $ concatMap (uncurry getChangeParams) (HashMap.toList cs)
297       Nothing -> error "No changes!"
298
299   modifyM $ \s -> do
300     newVFS <- liftIO $ changeFromServerVFS (vfs s) r
301     return $ s { vfs = newVFS }
302
303   let groupedParams = groupBy (\a b -> a ^. textDocument == b ^. textDocument) allChangeParams
304       mergedParams = map mergeParams groupedParams
305
306   -- TODO: Don't do this when replaying a session
307   forM_ mergedParams (sendMessage . NotificationMessage "2.0" TextDocumentDidChange)
308
309   -- Update VFS to new document versions
310   let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
311       latestVersions = map ((^. textDocument) . last) sortedVersions
312       bumpedVersions = map (version . _Just +~ 1) latestVersions
313
314   forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
315     modify $ \s ->
316       let oldVFS = vfs s
317           update (VirtualFile oldV file_ver t) = VirtualFile (fromMaybe oldV v) (file_ver + 1) t
318           newVFS = updateVFS (Map.adjust update (toNormalizedUri uri)) oldVFS
319       in s { vfs = newVFS }
320
321   where checkIfNeedsOpened uri = do
322           oldVFS <- vfs <$> get
323           ctx <- ask
324
325           -- if its not open, open it
326           unless (toNormalizedUri uri `Map.member` vfsMap oldVFS) $ do
327             let fp = fromJust $ uriToFilePath uri
328             contents <- liftIO $ T.readFile fp
329             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
330                 msg = NotificationMessage "2.0" TextDocumentDidOpen (DidOpenTextDocumentParams item)
331             liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
332
333             modifyM $ \s -> do
334               let (newVFS,_) = openVFS (vfs s) msg
335               return $ s { vfs = newVFS }
336
337         getParams (TextDocumentEdit docId (List edits)) =
338           let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
339             in DidChangeTextDocumentParams docId (List changeEvents)
340
341         textDocumentVersions uri = map (VersionedTextDocumentIdentifier uri . Just) [0..]
342
343         textDocumentEdits uri edits = map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip (textDocumentVersions uri) edits
344
345         getChangeParams uri (List edits) = map getParams (textDocumentEdits uri (reverse edits))
346
347         mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
348         mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
349                               in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
350 updateState _ = return ()
351
352 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
353 sendMessage msg = do
354   h <- serverIn <$> ask
355   logMsg LogClient msg
356   liftIO $ B.hPut h (addHeader $ encode msg)
357
358 -- | Execute a block f that will throw a 'Timeout' exception
359 -- after duration seconds. This will override the global timeout
360 -- for waiting for messages to arrive defined in 'SessionConfig'.
361 withTimeout :: Int -> Session a -> Session a
362 withTimeout duration f = do
363   chan <- asks messageChan
364   timeoutId <- getCurTimeoutId
365   modify $ \s -> s { overridingTimeout = True }
366   liftIO $ forkIO $ do
367     threadDelay (duration * 1000000)
368     writeChan chan (TimeoutMessage timeoutId)
369   res <- f
370   bumpTimeoutId timeoutId
371   modify $ \s -> s { overridingTimeout = False }
372   return res
373
374 data LogMsgType = LogServer | LogClient
375   deriving Eq
376
377 -- | Logs the message if the config specified it
378 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
379        => LogMsgType -> a -> m ()
380 logMsg t msg = do
381   shouldLog <- asks $ logMessages . config
382   shouldColor <- asks $ logColor . config
383   liftIO $ when shouldLog $ do
384     when shouldColor $ setSGR [SetColor Foreground Dull color]
385     putStrLn $ arrow ++ showPretty msg
386     when shouldColor $ setSGR [Reset]
387
388   where arrow
389           | t == LogServer  = "<-- "
390           | otherwise       = "--> "
391         color
392           | t == LogServer  = Magenta
393           | otherwise       = Cyan
394
395         showPretty = B.unpack . encodePretty
396
397