Track changes to haskell-lsp
[lsp-test.git] / src / Language / Haskell / LSP / Test / Session.hs
1 {-# LANGUAGE CPP               #-}
2 {-# LANGUAGE OverloadedStrings #-}
3 {-# LANGUAGE GeneralizedNewtypeDeriving #-}
4 {-# LANGUAGE FlexibleInstances #-}
5 {-# LANGUAGE MultiParamTypeClasses #-}
6 {-# LANGUAGE FlexibleContexts #-}
7 {-# LANGUAGE RankNTypes #-}
8
9 module Language.Haskell.LSP.Test.Session
10   ( Session(..)
11   , SessionConfig(..)
12   , defaultConfig
13   , SessionMessage(..)
14   , SessionContext(..)
15   , SessionState(..)
16   , runSessionWithHandles
17   , get
18   , put
19   , modify
20   , modifyM
21   , ask
22   , asks
23   , sendMessage
24   , updateState
25   , withTimeout
26   , logMsg
27   , LogMsgType(..)
28   )
29
30 where
31
32 import Control.Applicative
33 import Control.Concurrent hiding (yield)
34 import Control.Exception
35 import Control.Lens hiding (List)
36 import Control.Monad
37 import Control.Monad.IO.Class
38 import Control.Monad.Except
39 #if __GLASGOW_HASKELL__ == 806
40 import Control.Monad.Fail
41 #endif
42 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
43 import qualified Control.Monad.Trans.Reader as Reader (ask)
44 import Control.Monad.Trans.State (StateT, runStateT)
45 import qualified Control.Monad.Trans.State as State
46 import qualified Data.ByteString.Lazy.Char8 as B
47 import Data.Aeson
48 import Data.Aeson.Encode.Pretty
49 import Data.Conduit as Conduit
50 import Data.Conduit.Parser as Parser
51 import Data.Default
52 import Data.Foldable
53 import Data.List
54 import qualified Data.Map as Map
55 import qualified Data.Text as T
56 import qualified Data.Text.IO as T
57 import qualified Data.HashMap.Strict as HashMap
58 import Data.Maybe
59 import Data.Function
60 import Language.Haskell.LSP.Messages
61 import Language.Haskell.LSP.Types.Capabilities
62 import Language.Haskell.LSP.Types
63 import Language.Haskell.LSP.Types.Lens hiding (error)
64 import Language.Haskell.LSP.VFS
65 import Language.Haskell.LSP.Test.Compat
66 import Language.Haskell.LSP.Test.Decoding
67 import Language.Haskell.LSP.Test.Exceptions
68 import System.Console.ANSI
69 import System.Directory
70 import System.IO
71 import System.Process (ProcessHandle())
72 import System.Timeout
73
74 -- | A session representing one instance of launching and connecting to a server.
75 --
76 -- You can send and receive messages to the server within 'Session' via
77 -- 'Language.Haskell.LSP.Test.message',
78 -- 'Language.Haskell.LSP.Test.sendRequest' and
79 -- 'Language.Haskell.LSP.Test.sendNotification'.
80
81 newtype Session a = Session (ConduitParser FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) a)
82   deriving (Functor, Applicative, Monad, MonadIO, Alternative)
83
84 #if __GLASGOW_HASKELL__ >= 806
85 instance MonadFail Session where
86   fail s = do
87     lastMsg <- fromJust . lastReceivedMessage <$> get
88     liftIO $ throw (UnexpectedMessage s lastMsg)
89 #endif
90
91 -- | Stuff you can configure for a 'Session'.
92 data SessionConfig = SessionConfig
93   { messageTimeout :: Int  -- ^ Maximum time to wait for a message in seconds, defaults to 60.
94   , logStdErr      :: Bool -- ^ Redirect the server's stderr to this stdout, defaults to False.
95   , logMessages    :: Bool -- ^ Trace the messages sent and received to stdout, defaults to False.
96   , logColor       :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
97   , lspConfig      :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
98   -- ^ Whether or not to ignore 'ShowMessageNotification' and 'LogMessageNotification', defaults to False.
99   -- @since 0.9.0.0
100   , ignoreLogNotifications :: Bool
101   }
102
103 -- | The configuration used in 'Language.Haskell.LSP.Test.runSession'.
104 defaultConfig :: SessionConfig
105 defaultConfig = SessionConfig 60 False False True Nothing False
106
107 instance Default SessionConfig where
108   def = defaultConfig
109
110 data SessionMessage = ServerMessage FromServerMessage
111                     | TimeoutMessage Int
112   deriving Show
113
114 data SessionContext = SessionContext
115   {
116     serverIn :: Handle
117   , rootDir :: FilePath
118   , messageChan :: Chan SessionMessage
119   , requestMap :: MVar RequestMap
120   , initRsp :: MVar InitializeResponse
121   , config :: SessionConfig
122   , sessionCapabilities :: ClientCapabilities
123   }
124
125 class Monad m => HasReader r m where
126   ask :: m r
127   asks :: (r -> b) -> m b
128   asks f = f <$> ask
129
130 instance HasReader SessionContext Session where
131   ask  = Session (lift $ lift Reader.ask)
132
133 instance Monad m => HasReader r (ConduitM a b (StateT s (ReaderT r m))) where
134   ask = lift $ lift Reader.ask
135
136 data SessionState = SessionState
137   {
138     curReqId :: LspId
139   , vfs :: VFS
140   , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
141   , curTimeoutId :: Int
142   , overridingTimeout :: Bool
143   -- ^ The last received message from the server.
144   -- Used for providing exception information
145   , lastReceivedMessage :: Maybe FromServerMessage
146   }
147
148 class Monad m => HasState s m where
149   get :: m s
150
151   put :: s -> m ()
152
153   modify :: (s -> s) -> m ()
154   modify f = get >>= put . f
155
156   modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
157   modifyM f = get >>= f >>= put
158
159 instance HasState SessionState Session where
160   get = Session (lift State.get)
161   put = Session . lift . State.put
162
163 instance Monad m => HasState s (ConduitM a b (StateT s m))
164  where
165   get = lift State.get
166   put = lift . State.put
167
168 instance Monad m => HasState s (ConduitParser a (StateT s m))
169  where
170   get = lift State.get
171   put = lift . State.put
172
173 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
174 runSession context state (Session session) = runReaderT (runStateT conduit state) context
175   where
176     conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
177
178     handler (Unexpected "ConduitParser.empty") = do
179       lastMsg <- fromJust . lastReceivedMessage <$> get
180       name <- getParserName
181       liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
182
183     handler e = throw e
184
185     chanSource = do
186       msg <- liftIO $ readChan (messageChan context)
187       unless (ignoreLogNotifications (config context) && isLogNotification msg) $
188         yield msg
189       chanSource
190
191     isLogNotification (ServerMessage (NotShowMessage _)) = True
192     isLogNotification (ServerMessage (NotLogMessage _)) = True
193     isLogNotification _ = False
194
195     watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
196     watchdog = Conduit.awaitForever $ \msg -> do
197       curId <- curTimeoutId <$> get
198       case msg of
199         ServerMessage sMsg -> yield sMsg
200         TimeoutMessage tId -> when (curId == tId) $ lastReceivedMessage <$> get >>= throw . Timeout
201
202 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
203 -- It also does not automatically send initialize and exit messages.
204 runSessionWithHandles :: Handle -- ^ Server in
205                       -> Handle -- ^ Server out
206                       -> ProcessHandle -- ^ Server process
207                       -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
208                       -> SessionConfig
209                       -> ClientCapabilities
210                       -> FilePath -- ^ Root directory
211                       -> Session () -- ^ To exit the Server properly
212                       -> Session a
213                       -> IO a
214 runSessionWithHandles serverIn serverOut serverProc serverHandler config caps rootDir exitServer session = do
215   absRootDir <- canonicalizePath rootDir
216
217   hSetBuffering serverIn  NoBuffering
218   hSetBuffering serverOut NoBuffering
219   -- This is required to make sure that we don’t get any
220   -- newline conversion or weird encoding issues.
221   hSetBinaryMode serverIn True
222   hSetBinaryMode serverOut True
223
224   reqMap <- newMVar newRequestMap
225   messageChan <- newChan
226   initRsp <- newEmptyMVar
227
228   mainThreadId <- myThreadId
229
230   let context = SessionContext serverIn absRootDir messageChan reqMap initRsp config caps
231       initState vfs = SessionState (IdInt 0) vfs
232                                        mempty 0 False Nothing
233       runSession' ses = initVFS $ \vfs -> runSession context (initState vfs) ses
234
235       errorHandler = throwTo mainThreadId :: SessionException -> IO()
236       serverListenerLauncher =
237         forkIO $ catch (serverHandler serverOut context) errorHandler
238       server = (Just serverIn, Just serverOut, Nothing, serverProc)
239       serverAndListenerFinalizer tid =
240         finally (timeout (messageTimeout config * 1000000)
241                          (runSession' exitServer))
242                 (cleanupProcess server >> killThread tid)
243
244   (result, _) <- bracket serverListenerLauncher serverAndListenerFinalizer
245                          (const $ runSession' session)
246   return result
247
248 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
249 updateStateC = awaitForever $ \msg -> do
250   updateState msg
251   yield msg
252
253 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m)
254             => FromServerMessage -> m ()
255 updateState (NotPublishDiagnostics n) = do
256   let List diags = n ^. params . diagnostics
257       doc = n ^. params . uri
258   modify (\s ->
259     let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s)
260       in s { curDiagnostics = newDiags })
261
262 updateState (ReqApplyWorkspaceEdit r) = do
263
264   allChangeParams <- case r ^. params . edit . documentChanges of
265     Just (List cs) -> do
266       mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
267       return $ map getParams cs
268     Nothing -> case r ^. params . edit . changes of
269       Just cs -> do
270         mapM_ checkIfNeedsOpened (HashMap.keys cs)
271         return $ concatMap (uncurry getChangeParams) (HashMap.toList cs)
272       Nothing -> error "No changes!"
273
274   modifyM $ \s -> do
275     newVFS <- liftIO $ changeFromServerVFS (vfs s) r
276     return $ s { vfs = newVFS }
277
278   let groupedParams = groupBy (\a b -> a ^. textDocument == b ^. textDocument) allChangeParams
279       mergedParams = map mergeParams groupedParams
280
281   -- TODO: Don't do this when replaying a session
282   forM_ mergedParams (sendMessage . NotificationMessage "2.0" TextDocumentDidChange)
283
284   -- Update VFS to new document versions
285   let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
286       latestVersions = map ((^. textDocument) . last) sortedVersions
287       bumpedVersions = map (version . _Just +~ 1) latestVersions
288
289   forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
290     modify $ \s ->
291       let oldVFS = vfs s
292           update (VirtualFile oldV file_ver t) = VirtualFile (fromMaybe oldV v) (file_ver + 1) t
293           newVFS = updateVFS (Map.adjust update (toNormalizedUri uri)) oldVFS
294       in s { vfs = newVFS }
295
296   where checkIfNeedsOpened uri = do
297           oldVFS <- vfs <$> get
298           ctx <- ask
299
300           -- if its not open, open it
301           unless (toNormalizedUri uri `Map.member` vfsMap oldVFS) $ do
302             let fp = fromJust $ uriToFilePath uri
303             contents <- liftIO $ T.readFile fp
304             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
305                 msg = NotificationMessage "2.0" TextDocumentDidOpen (DidOpenTextDocumentParams item)
306             liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
307
308             modifyM $ \s -> do
309               let (newVFS,_) = openVFS (vfs s) msg
310               return $ s { vfs = newVFS }
311
312         getParams (TextDocumentEdit docId (List edits)) =
313           let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
314             in DidChangeTextDocumentParams docId (List changeEvents)
315
316         textDocumentVersions uri = map (VersionedTextDocumentIdentifier uri . Just) [0..]
317
318         textDocumentEdits uri edits = map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip (textDocumentVersions uri) edits
319
320         getChangeParams uri (List edits) = map getParams (textDocumentEdits uri (reverse edits))
321
322         mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
323         mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
324                               in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
325 updateState _ = return ()
326
327 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
328 sendMessage msg = do
329   h <- serverIn <$> ask
330   logMsg LogClient msg
331   liftIO $ B.hPut h (addHeader $ encode msg)
332
333 -- | Execute a block f that will throw a 'Timeout' exception
334 -- after duration seconds. This will override the global timeout
335 -- for waiting for messages to arrive defined in 'SessionConfig'.
336 withTimeout :: Int -> Session a -> Session a
337 withTimeout duration f = do
338   chan <- asks messageChan
339   timeoutId <- curTimeoutId <$> get
340   modify $ \s -> s { overridingTimeout = True }
341   liftIO $ forkIO $ do
342     threadDelay (duration * 1000000)
343     writeChan chan (TimeoutMessage timeoutId)
344   res <- f
345   modify $ \s -> s { curTimeoutId = timeoutId + 1,
346                      overridingTimeout = False
347                    }
348   return res
349
350 data LogMsgType = LogServer | LogClient
351   deriving Eq
352
353 -- | Logs the message if the config specified it
354 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
355        => LogMsgType -> a -> m ()
356 logMsg t msg = do
357   shouldLog <- asks $ logMessages . config
358   shouldColor <- asks $ logColor . config
359   liftIO $ when shouldLog $ do
360     when shouldColor $ setSGR [SetColor Foreground Dull color]
361     putStrLn $ arrow ++ showPretty msg
362     when shouldColor $ setSGR [Reset]
363
364   where arrow
365           | t == LogServer  = "<-- "
366           | otherwise       = "--> "
367         color
368           | t == LogServer  = Magenta
369           | otherwise       = Cyan
370
371         showPretty = B.unpack . encodePretty
372
373