ac4c9ff066bd5a3479b4b4181014954acc0acfa5
[lsp-test.git] / src / Language / Haskell / LSP / Test / Session.hs
1 {-# LANGUAGE CPP               #-}
2 {-# LANGUAGE OverloadedStrings #-}
3 {-# LANGUAGE GeneralizedNewtypeDeriving #-}
4 {-# LANGUAGE FlexibleInstances #-}
5 {-# LANGUAGE MultiParamTypeClasses #-}
6 {-# LANGUAGE FlexibleContexts #-}
7 {-# LANGUAGE RankNTypes #-}
8
9 module Language.Haskell.LSP.Test.Session
10   ( Session(..)
11   , SessionConfig(..)
12   , defaultConfig
13   , SessionMessage(..)
14   , SessionContext(..)
15   , SessionState(..)
16   , runSessionWithHandles
17   , get
18   , put
19   , modify
20   , modifyM
21   , ask
22   , asks
23   , sendMessage
24   , updateState
25   , withTimeout
26   , logMsg
27   , LogMsgType(..)
28   )
29
30 where
31
32 import Control.Applicative
33 import Control.Concurrent hiding (yield)
34 import Control.Exception
35 import Control.Lens hiding (List)
36 import Control.Monad
37 import Control.Monad.IO.Class
38 import Control.Monad.Except
39 #if __GLASGOW_HASKELL__ == 806
40 import Control.Monad.Fail
41 #endif
42 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
43 import qualified Control.Monad.Trans.Reader as Reader (ask)
44 import Control.Monad.Trans.State (StateT, runStateT)
45 import qualified Control.Monad.Trans.State as State
46 import qualified Data.ByteString.Lazy.Char8 as B
47 import Data.Aeson
48 import Data.Aeson.Encode.Pretty
49 import Data.Conduit as Conduit
50 import Data.Conduit.Parser as Parser
51 import Data.Default
52 import Data.Foldable
53 import Data.List
54 import qualified Data.Map as Map
55 import qualified Data.Text as T
56 import qualified Data.Text.IO as T
57 import qualified Data.HashMap.Strict as HashMap
58 import Data.Maybe
59 import Data.Function
60 import Language.Haskell.LSP.Messages
61 import Language.Haskell.LSP.Types.Capabilities
62 import Language.Haskell.LSP.Types
63 import Language.Haskell.LSP.Types.Lens hiding (error)
64 import Language.Haskell.LSP.VFS
65 import Language.Haskell.LSP.Test.Compat
66 import Language.Haskell.LSP.Test.Decoding
67 import Language.Haskell.LSP.Test.Exceptions
68 import System.Console.ANSI
69 import System.Directory
70 import System.IO
71 import System.Process (ProcessHandle())
72 import System.Timeout
73
74 -- | A session representing one instance of launching and connecting to a server.
75 --
76 -- You can send and receive messages to the server within 'Session' via
77 -- 'Language.Haskell.LSP.Test.message',
78 -- 'Language.Haskell.LSP.Test.sendRequest' and
79 -- 'Language.Haskell.LSP.Test.sendNotification'.
80
81 newtype Session a = Session (ConduitParser FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) a)
82   deriving (Functor, Applicative, Monad, MonadIO, Alternative)
83
84 #if __GLASGOW_HASKELL__ >= 806
85 instance MonadFail Session where
86   fail s = do
87     lastMsg <- fromJust . lastReceivedMessage <$> get
88     liftIO $ throw (UnexpectedMessage s lastMsg)
89 #endif
90
91 -- | Stuff you can configure for a 'Session'.
92 data SessionConfig = SessionConfig
93   { messageTimeout :: Int  -- ^ Maximum time to wait for a message in seconds, defaults to 60.
94   , logStdErr      :: Bool
95   -- ^ Redirect the server's stderr to this stdout, defaults to False.
96   -- Can be overriden with @LSP_TEST_LOG_STDERR@.
97   , logMessages    :: Bool
98   -- ^ Trace the messages sent and received to stdout, defaults to False.
99   -- Can be overriden with the environment variable @LSP_TEST_LOG_MESSAGES@.
100   , logColor       :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
101   , lspConfig      :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
102   , ignoreLogNotifications :: Bool
103   -- ^ Whether or not to ignore 'Language.Haskell.LSP.Types.ShowMessageNotification' and
104   -- 'Language.Haskell.LSP.Types.LogMessageNotification', defaults to False.
105   --
106   -- @since 0.9.0.0
107   }
108
109 -- | The configuration used in 'Language.Haskell.LSP.Test.runSession'.
110 defaultConfig :: SessionConfig
111 defaultConfig = SessionConfig 60 False False True Nothing False
112
113 instance Default SessionConfig where
114   def = defaultConfig
115
116 data SessionMessage = ServerMessage FromServerMessage
117                     | TimeoutMessage Int
118   deriving Show
119
120 data SessionContext = SessionContext
121   {
122     serverIn :: Handle
123   , rootDir :: FilePath
124   , messageChan :: Chan SessionMessage
125   , requestMap :: MVar RequestMap
126   , initRsp :: MVar InitializeResponse
127   , config :: SessionConfig
128   , sessionCapabilities :: ClientCapabilities
129   }
130
131 class Monad m => HasReader r m where
132   ask :: m r
133   asks :: (r -> b) -> m b
134   asks f = f <$> ask
135
136 instance HasReader SessionContext Session where
137   ask  = Session (lift $ lift Reader.ask)
138
139 instance Monad m => HasReader r (ConduitM a b (StateT s (ReaderT r m))) where
140   ask = lift $ lift Reader.ask
141
142 data SessionState = SessionState
143   {
144     curReqId :: LspId
145   , vfs :: VFS
146   , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
147   , curTimeoutId :: Int
148   , overridingTimeout :: Bool
149   -- ^ The last received message from the server.
150   -- Used for providing exception information
151   , lastReceivedMessage :: Maybe FromServerMessage
152   }
153
154 class Monad m => HasState s m where
155   get :: m s
156
157   put :: s -> m ()
158
159   modify :: (s -> s) -> m ()
160   modify f = get >>= put . f
161
162   modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
163   modifyM f = get >>= f >>= put
164
165 instance HasState SessionState Session where
166   get = Session (lift State.get)
167   put = Session . lift . State.put
168
169 instance Monad m => HasState s (ConduitM a b (StateT s m))
170  where
171   get = lift State.get
172   put = lift . State.put
173
174 instance Monad m => HasState s (ConduitParser a (StateT s m))
175  where
176   get = lift State.get
177   put = lift . State.put
178
179 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
180 runSession context state (Session session) = runReaderT (runStateT conduit state) context
181   where
182     conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
183
184     handler (Unexpected "ConduitParser.empty") = do
185       lastMsg <- fromJust . lastReceivedMessage <$> get
186       name <- getParserName
187       liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
188
189     handler e = throw e
190
191     chanSource = do
192       msg <- liftIO $ readChan (messageChan context)
193       unless (ignoreLogNotifications (config context) && isLogNotification msg) $
194         yield msg
195       chanSource
196
197     isLogNotification (ServerMessage (NotShowMessage _)) = True
198     isLogNotification (ServerMessage (NotLogMessage _)) = True
199     isLogNotification _ = False
200
201     watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
202     watchdog = Conduit.awaitForever $ \msg -> do
203       curId <- curTimeoutId <$> get
204       case msg of
205         ServerMessage sMsg -> yield sMsg
206         TimeoutMessage tId -> when (curId == tId) $ lastReceivedMessage <$> get >>= throw . Timeout
207
208 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
209 -- It also does not automatically send initialize and exit messages.
210 runSessionWithHandles :: Handle -- ^ Server in
211                       -> Handle -- ^ Server out
212                       -> ProcessHandle -- ^ Server process
213                       -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
214                       -> SessionConfig
215                       -> ClientCapabilities
216                       -> FilePath -- ^ Root directory
217                       -> Session () -- ^ To exit the Server properly
218                       -> Session a
219                       -> IO a
220 runSessionWithHandles serverIn serverOut serverProc serverHandler config caps rootDir exitServer session = do
221   absRootDir <- canonicalizePath rootDir
222
223   hSetBuffering serverIn  NoBuffering
224   hSetBuffering serverOut NoBuffering
225   -- This is required to make sure that we don’t get any
226   -- newline conversion or weird encoding issues.
227   hSetBinaryMode serverIn True
228   hSetBinaryMode serverOut True
229
230   reqMap <- newMVar newRequestMap
231   messageChan <- newChan
232   initRsp <- newEmptyMVar
233
234   mainThreadId <- myThreadId
235
236   let context = SessionContext serverIn absRootDir messageChan reqMap initRsp config caps
237       initState vfs = SessionState (IdInt 0) vfs
238                                        mempty 0 False Nothing
239       runSession' ses = initVFS $ \vfs -> runSession context (initState vfs) ses
240
241       errorHandler = throwTo mainThreadId :: SessionException -> IO()
242       serverListenerLauncher =
243         forkIO $ catch (serverHandler serverOut context) errorHandler
244       server = (Just serverIn, Just serverOut, Nothing, serverProc)
245       serverAndListenerFinalizer tid =
246         finally (timeout (messageTimeout config * 1000000)
247                          (runSession' exitServer))
248                 (cleanupProcess server >> killThread tid)
249
250   (result, _) <- bracket serverListenerLauncher serverAndListenerFinalizer
251                          (const $ runSession' session)
252   return result
253
254 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
255 updateStateC = awaitForever $ \msg -> do
256   updateState msg
257   yield msg
258
259 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m)
260             => FromServerMessage -> m ()
261 updateState (NotPublishDiagnostics n) = do
262   let List diags = n ^. params . diagnostics
263       doc = n ^. params . uri
264   modify (\s ->
265     let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s)
266       in s { curDiagnostics = newDiags })
267
268 updateState (ReqApplyWorkspaceEdit r) = do
269
270   allChangeParams <- case r ^. params . edit . documentChanges of
271     Just (List cs) -> do
272       mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
273       return $ map getParams cs
274     Nothing -> case r ^. params . edit . changes of
275       Just cs -> do
276         mapM_ checkIfNeedsOpened (HashMap.keys cs)
277         return $ concatMap (uncurry getChangeParams) (HashMap.toList cs)
278       Nothing -> error "No changes!"
279
280   modifyM $ \s -> do
281     newVFS <- liftIO $ changeFromServerVFS (vfs s) r
282     return $ s { vfs = newVFS }
283
284   let groupedParams = groupBy (\a b -> a ^. textDocument == b ^. textDocument) allChangeParams
285       mergedParams = map mergeParams groupedParams
286
287   -- TODO: Don't do this when replaying a session
288   forM_ mergedParams (sendMessage . NotificationMessage "2.0" TextDocumentDidChange)
289
290   -- Update VFS to new document versions
291   let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
292       latestVersions = map ((^. textDocument) . last) sortedVersions
293       bumpedVersions = map (version . _Just +~ 1) latestVersions
294
295   forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
296     modify $ \s ->
297       let oldVFS = vfs s
298           update (VirtualFile oldV file_ver t) = VirtualFile (fromMaybe oldV v) (file_ver + 1) t
299           newVFS = updateVFS (Map.adjust update (toNormalizedUri uri)) oldVFS
300       in s { vfs = newVFS }
301
302   where checkIfNeedsOpened uri = do
303           oldVFS <- vfs <$> get
304           ctx <- ask
305
306           -- if its not open, open it
307           unless (toNormalizedUri uri `Map.member` vfsMap oldVFS) $ do
308             let fp = fromJust $ uriToFilePath uri
309             contents <- liftIO $ T.readFile fp
310             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
311                 msg = NotificationMessage "2.0" TextDocumentDidOpen (DidOpenTextDocumentParams item)
312             liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
313
314             modifyM $ \s -> do
315               let (newVFS,_) = openVFS (vfs s) msg
316               return $ s { vfs = newVFS }
317
318         getParams (TextDocumentEdit docId (List edits)) =
319           let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
320             in DidChangeTextDocumentParams docId (List changeEvents)
321
322         textDocumentVersions uri = map (VersionedTextDocumentIdentifier uri . Just) [0..]
323
324         textDocumentEdits uri edits = map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip (textDocumentVersions uri) edits
325
326         getChangeParams uri (List edits) = map getParams (textDocumentEdits uri (reverse edits))
327
328         mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
329         mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
330                               in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
331 updateState _ = return ()
332
333 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
334 sendMessage msg = do
335   h <- serverIn <$> ask
336   logMsg LogClient msg
337   liftIO $ B.hPut h (addHeader $ encode msg)
338
339 -- | Execute a block f that will throw a 'Timeout' exception
340 -- after duration seconds. This will override the global timeout
341 -- for waiting for messages to arrive defined in 'SessionConfig'.
342 withTimeout :: Int -> Session a -> Session a
343 withTimeout duration f = do
344   chan <- asks messageChan
345   timeoutId <- curTimeoutId <$> get
346   modify $ \s -> s { overridingTimeout = True }
347   liftIO $ forkIO $ do
348     threadDelay (duration * 1000000)
349     writeChan chan (TimeoutMessage timeoutId)
350   res <- f
351   modify $ \s -> s { curTimeoutId = timeoutId + 1,
352                      overridingTimeout = False
353                    }
354   return res
355
356 data LogMsgType = LogServer | LogClient
357   deriving Eq
358
359 -- | Logs the message if the config specified it
360 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
361        => LogMsgType -> a -> m ()
362 logMsg t msg = do
363   shouldLog <- asks $ logMessages . config
364   shouldColor <- asks $ logColor . config
365   liftIO $ when shouldLog $ do
366     when shouldColor $ setSGR [SetColor Foreground Dull color]
367     putStrLn $ arrow ++ showPretty msg
368     when shouldColor $ setSGR [Reset]
369
370   where arrow
371           | t == LogServer  = "<-- "
372           | otherwise       = "--> "
373         color
374           | t == LogServer  = Magenta
375           | otherwise       = Cyan
376
377         showPretty = B.unpack . encodePretty
378
379