21c008643e95869badbe46daa67a1fe9bda9209c
[lsp-test.git] / src / Language / Haskell / LSP / Test / Session.hs
1 {-# LANGUAGE CPP               #-}
2 {-# LANGUAGE OverloadedStrings #-}
3 {-# LANGUAGE GeneralizedNewtypeDeriving #-}
4 {-# LANGUAGE FlexibleInstances #-}
5 {-# LANGUAGE MultiParamTypeClasses #-}
6 {-# LANGUAGE FlexibleContexts #-}
7 {-# LANGUAGE RankNTypes #-}
8
9 module Language.Haskell.LSP.Test.Session
10   ( Session(..)
11   , SessionConfig(..)
12   , defaultConfig
13   , SessionMessage(..)
14   , SessionContext(..)
15   , SessionState(..)
16   , runSessionWithHandles
17   , get
18   , put
19   , modify
20   , modifyM
21   , ask
22   , asks
23   , sendMessage
24   , updateState
25   , withTimeout
26   , logMsg
27   , LogMsgType(..)
28   )
29
30 where
31
32 import Control.Applicative
33 import Control.Concurrent hiding (yield)
34 import Control.Exception
35 import Control.Lens hiding (List)
36 import Control.Monad
37 import Control.Monad.IO.Class
38 import Control.Monad.Except
39 #if __GLASGOW_HASKELL__ == 806
40 import Control.Monad.Fail
41 #endif
42 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
43 import qualified Control.Monad.Trans.Reader as Reader (ask)
44 import Control.Monad.Trans.State (StateT, runStateT)
45 import qualified Control.Monad.Trans.State as State
46 import qualified Data.ByteString.Lazy.Char8 as B
47 import Data.Aeson
48 import Data.Aeson.Encode.Pretty
49 import Data.Conduit as Conduit
50 import Data.Conduit.Parser as Parser
51 import Data.Default
52 import Data.Foldable
53 import Data.List
54 import qualified Data.Map as Map
55 import qualified Data.Text as T
56 import qualified Data.Text.IO as T
57 import qualified Data.HashMap.Strict as HashMap
58 import Data.Maybe
59 import Data.Function
60 import Language.Haskell.LSP.Messages
61 import Language.Haskell.LSP.Types.Capabilities
62 import Language.Haskell.LSP.Types
63 import Language.Haskell.LSP.Types.Lens hiding (error)
64 import Language.Haskell.LSP.VFS
65 import Language.Haskell.LSP.Test.Compat
66 import Language.Haskell.LSP.Test.Decoding
67 import Language.Haskell.LSP.Test.Exceptions
68 import System.Console.ANSI
69 import System.Directory
70 import System.IO
71 import System.Process (ProcessHandle())
72 import System.Timeout
73
74 -- | A session representing one instance of launching and connecting to a server.
75 --
76 -- You can send and receive messages to the server within 'Session' via
77 -- 'Language.Haskell.LSP.Test.message',
78 -- 'Language.Haskell.LSP.Test.sendRequest' and
79 -- 'Language.Haskell.LSP.Test.sendNotification'.
80
81 newtype Session a = Session (ConduitParser FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) a)
82   deriving (Functor, Applicative, Monad, MonadIO, Alternative)
83
84 #if __GLASGOW_HASKELL__ >= 806
85 instance MonadFail Session where
86   fail s = do
87     lastMsg <- fromJust . lastReceivedMessage <$> get
88     liftIO $ throw (UnexpectedMessage s lastMsg)
89 #endif
90
91 -- | Stuff you can configure for a 'Session'.
92 data SessionConfig = SessionConfig
93   { messageTimeout :: Int  -- ^ Maximum time to wait for a message in seconds, defaults to 60.
94   , logStdErr      :: Bool
95   -- ^ Redirect the server's stderr to this stdout, defaults to False.
96   -- Can be overriden with @LSP_TEST_LOG_STDERR@.
97   , logMessages    :: Bool
98   -- ^ Trace the messages sent and received to stdout, defaults to False.
99   -- Can be overriden with the environment variable @LSP_TEST_LOG_MESSAGES@.
100   , logColor       :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
101   , lspConfig      :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
102   -- ^ Whether or not to ignore 'ShowMessageNotification' and 'LogMessageNotification', defaults to False.
103   -- @since 0.9.0.0
104   , ignoreLogNotifications :: Bool
105   }
106
107 -- | The configuration used in 'Language.Haskell.LSP.Test.runSession'.
108 defaultConfig :: SessionConfig
109 defaultConfig = SessionConfig 60 False False True Nothing False
110
111 instance Default SessionConfig where
112   def = defaultConfig
113
114 data SessionMessage = ServerMessage FromServerMessage
115                     | TimeoutMessage Int
116   deriving Show
117
118 data SessionContext = SessionContext
119   {
120     serverIn :: Handle
121   , rootDir :: FilePath
122   , messageChan :: Chan SessionMessage
123   , requestMap :: MVar RequestMap
124   , initRsp :: MVar InitializeResponse
125   , config :: SessionConfig
126   , sessionCapabilities :: ClientCapabilities
127   }
128
129 class Monad m => HasReader r m where
130   ask :: m r
131   asks :: (r -> b) -> m b
132   asks f = f <$> ask
133
134 instance HasReader SessionContext Session where
135   ask  = Session (lift $ lift Reader.ask)
136
137 instance Monad m => HasReader r (ConduitM a b (StateT s (ReaderT r m))) where
138   ask = lift $ lift Reader.ask
139
140 data SessionState = SessionState
141   {
142     curReqId :: LspId
143   , vfs :: VFS
144   , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
145   , curTimeoutId :: Int
146   , overridingTimeout :: Bool
147   -- ^ The last received message from the server.
148   -- Used for providing exception information
149   , lastReceivedMessage :: Maybe FromServerMessage
150   }
151
152 class Monad m => HasState s m where
153   get :: m s
154
155   put :: s -> m ()
156
157   modify :: (s -> s) -> m ()
158   modify f = get >>= put . f
159
160   modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
161   modifyM f = get >>= f >>= put
162
163 instance HasState SessionState Session where
164   get = Session (lift State.get)
165   put = Session . lift . State.put
166
167 instance Monad m => HasState s (ConduitM a b (StateT s m))
168  where
169   get = lift State.get
170   put = lift . State.put
171
172 instance Monad m => HasState s (ConduitParser a (StateT s m))
173  where
174   get = lift State.get
175   put = lift . State.put
176
177 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
178 runSession context state (Session session) = runReaderT (runStateT conduit state) context
179   where
180     conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
181
182     handler (Unexpected "ConduitParser.empty") = do
183       lastMsg <- fromJust . lastReceivedMessage <$> get
184       name <- getParserName
185       liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
186
187     handler e = throw e
188
189     chanSource = do
190       msg <- liftIO $ readChan (messageChan context)
191       unless (ignoreLogNotifications (config context) && isLogNotification msg) $
192         yield msg
193       chanSource
194
195     isLogNotification (ServerMessage (NotShowMessage _)) = True
196     isLogNotification (ServerMessage (NotLogMessage _)) = True
197     isLogNotification _ = False
198
199     watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
200     watchdog = Conduit.awaitForever $ \msg -> do
201       curId <- curTimeoutId <$> get
202       case msg of
203         ServerMessage sMsg -> yield sMsg
204         TimeoutMessage tId -> when (curId == tId) $ lastReceivedMessage <$> get >>= throw . Timeout
205
206 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
207 -- It also does not automatically send initialize and exit messages.
208 runSessionWithHandles :: Handle -- ^ Server in
209                       -> Handle -- ^ Server out
210                       -> ProcessHandle -- ^ Server process
211                       -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
212                       -> SessionConfig
213                       -> ClientCapabilities
214                       -> FilePath -- ^ Root directory
215                       -> Session () -- ^ To exit the Server properly
216                       -> Session a
217                       -> IO a
218 runSessionWithHandles serverIn serverOut serverProc serverHandler config caps rootDir exitServer session = do
219   absRootDir <- canonicalizePath rootDir
220
221   hSetBuffering serverIn  NoBuffering
222   hSetBuffering serverOut NoBuffering
223   -- This is required to make sure that we don’t get any
224   -- newline conversion or weird encoding issues.
225   hSetBinaryMode serverIn True
226   hSetBinaryMode serverOut True
227
228   reqMap <- newMVar newRequestMap
229   messageChan <- newChan
230   initRsp <- newEmptyMVar
231
232   mainThreadId <- myThreadId
233
234   let context = SessionContext serverIn absRootDir messageChan reqMap initRsp config caps
235       initState vfs = SessionState (IdInt 0) vfs
236                                        mempty 0 False Nothing
237       runSession' ses = initVFS $ \vfs -> runSession context (initState vfs) ses
238
239       errorHandler = throwTo mainThreadId :: SessionException -> IO()
240       serverListenerLauncher =
241         forkIO $ catch (serverHandler serverOut context) errorHandler
242       server = (Just serverIn, Just serverOut, Nothing, serverProc)
243       serverAndListenerFinalizer tid =
244         finally (timeout (messageTimeout config * 1000000)
245                          (runSession' exitServer))
246                 (cleanupProcess server >> killThread tid)
247
248   (result, _) <- bracket serverListenerLauncher serverAndListenerFinalizer
249                          (const $ runSession' session)
250   return result
251
252 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
253 updateStateC = awaitForever $ \msg -> do
254   updateState msg
255   yield msg
256
257 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m)
258             => FromServerMessage -> m ()
259 updateState (NotPublishDiagnostics n) = do
260   let List diags = n ^. params . diagnostics
261       doc = n ^. params . uri
262   modify (\s ->
263     let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s)
264       in s { curDiagnostics = newDiags })
265
266 updateState (ReqApplyWorkspaceEdit r) = do
267
268   allChangeParams <- case r ^. params . edit . documentChanges of
269     Just (List cs) -> do
270       mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
271       return $ map getParams cs
272     Nothing -> case r ^. params . edit . changes of
273       Just cs -> do
274         mapM_ checkIfNeedsOpened (HashMap.keys cs)
275         return $ concatMap (uncurry getChangeParams) (HashMap.toList cs)
276       Nothing -> error "No changes!"
277
278   modifyM $ \s -> do
279     newVFS <- liftIO $ changeFromServerVFS (vfs s) r
280     return $ s { vfs = newVFS }
281
282   let groupedParams = groupBy (\a b -> a ^. textDocument == b ^. textDocument) allChangeParams
283       mergedParams = map mergeParams groupedParams
284
285   -- TODO: Don't do this when replaying a session
286   forM_ mergedParams (sendMessage . NotificationMessage "2.0" TextDocumentDidChange)
287
288   -- Update VFS to new document versions
289   let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
290       latestVersions = map ((^. textDocument) . last) sortedVersions
291       bumpedVersions = map (version . _Just +~ 1) latestVersions
292
293   forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
294     modify $ \s ->
295       let oldVFS = vfs s
296           update (VirtualFile oldV file_ver t) = VirtualFile (fromMaybe oldV v) (file_ver + 1) t
297           newVFS = updateVFS (Map.adjust update (toNormalizedUri uri)) oldVFS
298       in s { vfs = newVFS }
299
300   where checkIfNeedsOpened uri = do
301           oldVFS <- vfs <$> get
302           ctx <- ask
303
304           -- if its not open, open it
305           unless (toNormalizedUri uri `Map.member` vfsMap oldVFS) $ do
306             let fp = fromJust $ uriToFilePath uri
307             contents <- liftIO $ T.readFile fp
308             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
309                 msg = NotificationMessage "2.0" TextDocumentDidOpen (DidOpenTextDocumentParams item)
310             liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
311
312             modifyM $ \s -> do
313               let (newVFS,_) = openVFS (vfs s) msg
314               return $ s { vfs = newVFS }
315
316         getParams (TextDocumentEdit docId (List edits)) =
317           let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
318             in DidChangeTextDocumentParams docId (List changeEvents)
319
320         textDocumentVersions uri = map (VersionedTextDocumentIdentifier uri . Just) [0..]
321
322         textDocumentEdits uri edits = map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip (textDocumentVersions uri) edits
323
324         getChangeParams uri (List edits) = map getParams (textDocumentEdits uri (reverse edits))
325
326         mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
327         mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
328                               in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
329 updateState _ = return ()
330
331 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
332 sendMessage msg = do
333   h <- serverIn <$> ask
334   logMsg LogClient msg
335   liftIO $ B.hPut h (addHeader $ encode msg)
336
337 -- | Execute a block f that will throw a 'Timeout' exception
338 -- after duration seconds. This will override the global timeout
339 -- for waiting for messages to arrive defined in 'SessionConfig'.
340 withTimeout :: Int -> Session a -> Session a
341 withTimeout duration f = do
342   chan <- asks messageChan
343   timeoutId <- curTimeoutId <$> get
344   modify $ \s -> s { overridingTimeout = True }
345   liftIO $ forkIO $ do
346     threadDelay (duration * 1000000)
347     writeChan chan (TimeoutMessage timeoutId)
348   res <- f
349   modify $ \s -> s { curTimeoutId = timeoutId + 1,
350                      overridingTimeout = False
351                    }
352   return res
353
354 data LogMsgType = LogServer | LogClient
355   deriving Eq
356
357 -- | Logs the message if the config specified it
358 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
359        => LogMsgType -> a -> m ()
360 logMsg t msg = do
361   shouldLog <- asks $ logMessages . config
362   shouldColor <- asks $ logColor . config
363   liftIO $ when shouldLog $ do
364     when shouldColor $ setSGR [SetColor Foreground Dull color]
365     putStrLn $ arrow ++ showPretty msg
366     when shouldColor $ setSGR [Reset]
367
368   where arrow
369           | t == LogServer  = "<-- "
370           | otherwise       = "--> "
371         color
372           | t == LogServer  = Magenta
373           | otherwise       = Cyan
374
375         showPretty = B.unpack . encodePretty
376
377