Bubble up exceptions thrown on server listener thread
[lsp-test.git] / src / Language / Haskell / LSP / Test / Session.hs
1 {-# LANGUAGE CPP               #-}
2 {-# LANGUAGE OverloadedStrings #-}
3 {-# LANGUAGE FlexibleInstances #-}
4 {-# LANGUAGE MultiParamTypeClasses #-}
5 {-# LANGUAGE FlexibleContexts #-}
6 {-# LANGUAGE RankNTypes #-}
7
8 module Language.Haskell.LSP.Test.Session
9   ( Session
10   , SessionConfig(..)
11   , defaultConfig
12   , SessionMessage(..)
13   , SessionContext(..)
14   , SessionState(..)
15   , runSessionWithHandles
16   , get
17   , put
18   , modify
19   , modifyM
20   , ask
21   , asks
22   , sendMessage
23   , updateState
24   , withTimeout
25   , logMsg
26   , LogMsgType(..)
27   )
28
29 where
30
31 import Control.Concurrent hiding (yield)
32 import Control.Exception
33 import Control.Lens hiding (List)
34 import Control.Monad
35 import Control.Monad.IO.Class
36 import Control.Monad.Except
37 #if __GLASGOW_HASKELL__ >= 806
38 import Control.Monad.Fail
39 #endif
40 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
41 import qualified Control.Monad.Trans.Reader as Reader (ask)
42 import Control.Monad.Trans.State (StateT, runStateT)
43 import qualified Control.Monad.Trans.State as State (get, put)
44 import qualified Data.ByteString.Lazy.Char8 as B
45 import Data.Aeson
46 import Data.Aeson.Encode.Pretty
47 import Data.Conduit as Conduit
48 import Data.Conduit.Parser as Parser
49 import Data.Default
50 import Data.Foldable
51 import Data.List
52 import qualified Data.Map as Map
53 import qualified Data.Text as T
54 import qualified Data.Text.IO as T
55 import qualified Data.HashMap.Strict as HashMap
56 import Data.Maybe
57 import Data.Function
58 import Language.Haskell.LSP.Messages
59 import Language.Haskell.LSP.Types.Capabilities
60 import Language.Haskell.LSP.Types
61 import Language.Haskell.LSP.Types.Lens hiding (error)
62 import Language.Haskell.LSP.VFS
63 import Language.Haskell.LSP.Test.Decoding
64 import Language.Haskell.LSP.Test.Exceptions
65 import System.Console.ANSI
66 import System.Directory
67 import System.IO
68
69 -- | A session representing one instance of launching and connecting to a server.
70 -- 
71 -- You can send and receive messages to the server within 'Session' via 'getMessage',
72 -- 'sendRequest' and 'sendNotification'.
73 --
74
75 type Session = ParserStateReader FromServerMessage SessionState SessionContext IO
76
77 #if __GLASGOW_HASKELL__ >= 806
78 instance MonadFail Session where
79   fail s = do
80     lastMsg <- fromJust . lastReceivedMessage <$> get
81     liftIO $ throw (UnexpectedMessage s lastMsg)
82 #endif
83
84 -- | Stuff you can configure for a 'Session'.
85 data SessionConfig = SessionConfig
86   { messageTimeout :: Int  -- ^ Maximum time to wait for a message in seconds, defaults to 60.
87   , logStdErr      :: Bool -- ^ Redirect the server's stderr to this stdout, defaults to False.
88   , logMessages    :: Bool -- ^ Trace the messages sent and received to stdout, defaults to False.
89   , logColor       :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
90   , lspConfig      :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
91   }
92
93 -- | The configuration used in 'Language.Haskell.LSP.Test.runSession'.
94 defaultConfig :: SessionConfig
95 defaultConfig = SessionConfig 60 False False True Nothing
96
97 instance Default SessionConfig where
98   def = defaultConfig
99
100 data SessionMessage = ServerMessage FromServerMessage
101                     | TimeoutMessage Int
102   deriving Show
103
104 data SessionContext = SessionContext
105   {
106     serverIn :: Handle
107   , rootDir :: FilePath
108   , messageChan :: Chan SessionMessage
109   , requestMap :: MVar RequestMap
110   , initRsp :: MVar InitializeResponse
111   , config :: SessionConfig
112   , sessionCapabilities :: ClientCapabilities
113   }
114
115 class Monad m => HasReader r m where
116   ask :: m r
117   asks :: (r -> b) -> m b
118   asks f = f <$> ask
119
120 instance Monad m => HasReader r (ParserStateReader a s r m) where
121   ask = lift $ lift Reader.ask
122
123 instance Monad m => HasReader SessionContext (ConduitM a b (StateT s (ReaderT SessionContext m))) where
124   ask = lift $ lift Reader.ask
125
126 data SessionState = SessionState
127   {
128     curReqId :: LspId
129   , vfs :: VFS
130   , curDiagnostics :: Map.Map Uri [Diagnostic]
131   , curTimeoutId :: Int
132   , overridingTimeout :: Bool
133   -- ^ The last received message from the server.
134   -- Used for providing exception information
135   , lastReceivedMessage :: Maybe FromServerMessage
136   }
137
138 class Monad m => HasState s m where
139   get :: m s
140
141   put :: s -> m ()
142
143   modify :: (s -> s) -> m ()
144   modify f = get >>= put . f
145
146   modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
147   modifyM f = get >>= f >>= put
148
149 instance Monad m => HasState s (ParserStateReader a s r m) where
150   get = lift State.get
151   put = lift . State.put
152
153 instance Monad m => HasState SessionState (ConduitM a b (StateT SessionState m))
154  where
155   get = lift State.get
156   put = lift . State.put
157
158 type ParserStateReader a s r m = ConduitParser a (StateT s (ReaderT r m))
159
160 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
161 runSession context state session = runReaderT (runStateT conduit state) context
162   where
163     conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
164
165     handler (Unexpected "ConduitParser.empty") = do
166       lastMsg <- fromJust . lastReceivedMessage <$> get
167       name <- getParserName
168       liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
169
170     handler e = throw e
171
172     chanSource = do
173       msg <- liftIO $ readChan (messageChan context)
174       yield msg
175       chanSource
176
177     watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
178     watchdog = Conduit.awaitForever $ \msg -> do
179       curId <- curTimeoutId <$> get
180       case msg of
181         ServerMessage sMsg -> yield sMsg
182         TimeoutMessage tId -> when (curId == tId) $ throw Timeout
183
184 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
185 -- It also does not automatically send initialize and exit messages.
186 runSessionWithHandles :: Handle -- ^ Server in
187                       -> Handle -- ^ Server out
188                       -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
189                       -> SessionConfig
190                       -> ClientCapabilities
191                       -> FilePath -- ^ Root directory
192                       -> Session a
193                       -> IO a
194 runSessionWithHandles serverIn serverOut serverHandler config caps rootDir session = do
195   absRootDir <- canonicalizePath rootDir
196
197   hSetBuffering serverIn  NoBuffering
198   hSetBuffering serverOut NoBuffering
199
200   reqMap <- newMVar newRequestMap
201   messageChan <- newChan
202   initRsp <- newEmptyMVar
203
204   mainThreadId <- myThreadId
205
206   let context = SessionContext serverIn absRootDir messageChan reqMap initRsp config caps
207       initState = SessionState (IdInt 0) mempty mempty 0 False Nothing
208       launchServerHandler = forkIO $ catch (serverHandler serverOut context)
209                                            (throwTo mainThreadId :: SessionException -> IO ())
210   (result, _) <- bracket launchServerHandler killThread $
211     const $ runSession context initState session
212
213   return result
214
215 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
216 updateStateC = awaitForever $ \msg -> do
217   updateState msg
218   yield msg
219
220 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m) => FromServerMessage -> m ()
221 updateState (NotPublishDiagnostics n) = do
222   let List diags = n ^. params . diagnostics
223       doc = n ^. params . uri
224   modify (\s ->
225     let newDiags = Map.insert doc diags (curDiagnostics s)
226       in s { curDiagnostics = newDiags })
227
228 updateState (ReqApplyWorkspaceEdit r) = do
229
230   allChangeParams <- case r ^. params . edit . documentChanges of
231     Just (List cs) -> do
232       mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
233       return $ map getParams cs
234     Nothing -> case r ^. params . edit . changes of
235       Just cs -> do
236         mapM_ checkIfNeedsOpened (HashMap.keys cs)
237         return $ concatMap (uncurry getChangeParams) (HashMap.toList cs)
238       Nothing -> error "No changes!"
239
240   modifyM $ \s -> do
241     newVFS <- liftIO $ changeFromServerVFS (vfs s) r
242     return $ s { vfs = newVFS }
243
244   let groupedParams = groupBy (\a b -> (a ^. textDocument == b ^. textDocument)) allChangeParams
245       mergedParams = map mergeParams groupedParams
246
247   -- TODO: Don't do this when replaying a session
248   forM_ mergedParams (sendMessage . NotificationMessage "2.0" TextDocumentDidChange)
249
250   -- Update VFS to new document versions
251   let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
252       latestVersions = map ((^. textDocument) . last) sortedVersions
253       bumpedVersions = map (version . _Just +~ 1) latestVersions
254
255   forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
256     modify $ \s ->
257       let oldVFS = vfs s
258           update (VirtualFile oldV t) = VirtualFile (fromMaybe oldV v) t
259           newVFS = Map.adjust update uri oldVFS
260       in s { vfs = newVFS }
261
262   where checkIfNeedsOpened uri = do
263           oldVFS <- vfs <$> get
264           ctx <- ask
265
266           -- if its not open, open it
267           unless (uri `Map.member` oldVFS) $ do
268             let fp = fromJust $ uriToFilePath uri
269             contents <- liftIO $ T.readFile fp
270             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
271                 msg = NotificationMessage "2.0" TextDocumentDidOpen (DidOpenTextDocumentParams item)
272             liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
273
274             modifyM $ \s -> do
275               newVFS <- liftIO $ openVFS (vfs s) msg
276               return $ s { vfs = newVFS }
277
278         getParams (TextDocumentEdit docId (List edits)) =
279           let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
280             in DidChangeTextDocumentParams docId (List changeEvents)
281
282         textDocumentVersions uri = map (VersionedTextDocumentIdentifier uri . Just) [0..]
283
284         textDocumentEdits uri edits = map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip (textDocumentVersions uri) edits
285
286         getChangeParams uri (List edits) = map getParams (textDocumentEdits uri (reverse edits))
287
288         mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
289         mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
290                               in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
291 updateState _ = return ()
292
293 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
294 sendMessage msg = do
295   h <- serverIn <$> ask
296   logMsg LogClient msg
297   liftIO $ B.hPut h (addHeader $ encode msg)
298
299 -- | Execute a block f that will throw a 'TimeoutException'
300 -- after duration seconds. This will override the global timeout
301 -- for waiting for messages to arrive defined in 'SessionConfig'.
302 withTimeout :: Int -> Session a -> Session a
303 withTimeout duration f = do
304   chan <- asks messageChan
305   timeoutId <- curTimeoutId <$> get
306   modify $ \s -> s { overridingTimeout = True }
307   liftIO $ forkIO $ do
308     threadDelay (duration * 1000000)
309     writeChan chan (TimeoutMessage timeoutId)
310   res <- f
311   modify $ \s -> s { curTimeoutId = timeoutId + 1,
312                      overridingTimeout = False
313                    }
314   return res
315
316 data LogMsgType = LogServer | LogClient
317   deriving Eq
318
319 -- | Logs the message if the config specified it
320 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
321        => LogMsgType -> a -> m ()
322 logMsg t msg = do
323   shouldLog <- asks $ logMessages . config
324   shouldColor <- asks $ logColor . config
325   liftIO $ when shouldLog $ do
326     when shouldColor $ setSGR [SetColor Foreground Dull color]
327     putStrLn $ arrow ++ showPretty msg
328     when shouldColor $ setSGR [Reset]
329
330   where arrow
331           | t == LogServer  = "<-- "
332           | otherwise       = "--> "
333         color
334           | t == LogServer  = Magenta
335           | otherwise       = Cyan
336
337         showPretty = B.unpack . encodePretty
338