Add server shutdown check to throw exception
[lsp-test.git] / src / Language / Haskell / LSP / Test / Session.hs
1 {-# LANGUAGE CPP               #-}
2 {-# LANGUAGE OverloadedStrings #-}
3 {-# LANGUAGE FlexibleInstances #-}
4 {-# LANGUAGE MultiParamTypeClasses #-}
5 {-# LANGUAGE FlexibleContexts #-}
6 {-# LANGUAGE RankNTypes #-}
7
8 module Language.Haskell.LSP.Test.Session
9   ( Session
10   , SessionConfig(..)
11   , defaultConfig
12   , SessionMessage(..)
13   , SessionContext(..)
14   , SessionState(..)
15   , runSessionWithHandles
16   , get
17   , put
18   , modify
19   , modifyM
20   , ask
21   , asks
22   , sendMessage
23   , updateState
24   , withTimeout
25   , logMsg
26   , LogMsgType(..)
27   )
28
29 where
30
31 import Control.Concurrent hiding (yield)
32 import Control.Exception
33 import Control.Lens hiding (List)
34 import Control.Monad
35 import Control.Monad.IO.Class
36 import Control.Monad.Except
37 #if __GLASGOW_HASKELL__ >= 806
38 import Control.Monad.Fail
39 #endif
40 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
41 import qualified Control.Monad.Trans.Reader as Reader (ask)
42 import Control.Monad.Trans.State (StateT, runStateT)
43 import qualified Control.Monad.Trans.State as State (get, put)
44 import qualified Data.ByteString.Lazy.Char8 as B
45 import Data.Aeson
46 import Data.Aeson.Encode.Pretty
47 import Data.Conduit as Conduit
48 import Data.Conduit.Parser as Parser
49 import Data.Default
50 import Data.Foldable
51 import Data.IORef
52 import Data.List
53 import qualified Data.Map as Map
54 import qualified Data.Text as T
55 import qualified Data.Text.IO as T
56 import qualified Data.HashMap.Strict as HashMap
57 import Data.Maybe
58 import Data.Function
59 import Language.Haskell.LSP.Messages
60 import Language.Haskell.LSP.Types.Capabilities
61 import Language.Haskell.LSP.Types
62 import Language.Haskell.LSP.Types.Lens hiding (error)
63 import Language.Haskell.LSP.VFS
64 import Language.Haskell.LSP.Test.Decoding
65 import Language.Haskell.LSP.Test.Exceptions
66 import System.Console.ANSI
67 import System.Directory
68 import System.IO
69
70 -- | A session representing one instance of launching and connecting to a server.
71 --
72 -- You can send and receive messages to the server within 'Session' via
73 -- 'Language.Haskell.LSP.Test.message',
74 -- 'Language.Haskell.LSP.Test.sendRequest' and
75 -- 'Language.Haskell.LSP.Test.sendNotification'.
76
77 type Session = ParserStateReader FromServerMessage SessionState SessionContext IO
78
79 #if __GLASGOW_HASKELL__ >= 806
80 instance MonadFail Session where
81   fail s = do
82     lastMsg <- fromJust . lastReceivedMessage <$> get
83     liftIO $ throw (UnexpectedMessage s lastMsg)
84 #endif
85
86 -- | Stuff you can configure for a 'Session'.
87 data SessionConfig = SessionConfig
88   { messageTimeout :: Int  -- ^ Maximum time to wait for a message in seconds, defaults to 60.
89   , logStdErr      :: Bool -- ^ Redirect the server's stderr to this stdout, defaults to False.
90   , logMessages    :: Bool -- ^ Trace the messages sent and received to stdout, defaults to False.
91   , logColor       :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
92   , lspConfig      :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
93   }
94
95 -- | The configuration used in 'Language.Haskell.LSP.Test.runSession'.
96 defaultConfig :: SessionConfig
97 defaultConfig = SessionConfig 60 False False True Nothing
98
99 instance Default SessionConfig where
100   def = defaultConfig
101
102 data SessionMessage = ServerMessage FromServerMessage
103                     | TimeoutMessage Int
104   deriving Show
105
106 data SessionContext = SessionContext
107   {
108     serverIn :: Handle
109   , rootDir :: FilePath
110   , messageChan :: Chan SessionMessage
111   , requestMap :: MVar RequestMap
112   , initRsp :: MVar InitializeResponse
113   , config :: SessionConfig
114   , sessionCapabilities :: ClientCapabilities
115   }
116
117 class Monad m => HasReader r m where
118   ask :: m r
119   asks :: (r -> b) -> m b
120   asks f = f <$> ask
121
122 instance Monad m => HasReader r (ParserStateReader a s r m) where
123   ask = lift $ lift Reader.ask
124
125 instance Monad m => HasReader SessionContext (ConduitM a b (StateT s (ReaderT SessionContext m))) where
126   ask = lift $ lift Reader.ask
127
128 data SessionState = SessionState
129   {
130     curReqId :: LspId
131   , vfs :: VFS
132   , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
133   , curTimeoutId :: Int
134   , overridingTimeout :: Bool
135   -- ^ The last received message from the server.
136   -- Used for providing exception information
137   , lastReceivedMessage :: Maybe FromServerMessage
138   }
139
140 class Monad m => HasState s m where
141   get :: m s
142
143   put :: s -> m ()
144
145   modify :: (s -> s) -> m ()
146   modify f = get >>= put . f
147
148   modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
149   modifyM f = get >>= f >>= put
150
151 instance Monad m => HasState s (ParserStateReader a s r m) where
152   get = lift State.get
153   put = lift . State.put
154
155 instance Monad m => HasState SessionState (ConduitM a b (StateT SessionState m))
156  where
157   get = lift State.get
158   put = lift . State.put
159
160 type ParserStateReader a s r m = ConduitParser a (StateT s (ReaderT r m))
161
162 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
163 runSession context state session = runReaderT (runStateT conduit state) context
164   where
165     conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
166
167     handler (Unexpected "ConduitParser.empty") = do
168       lastMsg <- fromJust . lastReceivedMessage <$> get
169       name <- getParserName
170       liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
171
172     handler e = throw e
173
174     chanSource = do
175       msg <- liftIO $ readChan (messageChan context)
176       yield msg
177       chanSource
178
179     watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
180     watchdog = Conduit.awaitForever $ \msg -> do
181       curId <- curTimeoutId <$> get
182       case msg of
183         ServerMessage sMsg -> yield sMsg
184         TimeoutMessage tId -> when (curId == tId) $ throw Timeout
185
186 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
187 -- It also does not automatically send initialize and exit messages.
188 runSessionWithHandles :: Handle -- ^ Server in
189                       -> Handle -- ^ Server out
190                       -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
191                       -> SessionConfig
192                       -> ClientCapabilities
193                       -> FilePath -- ^ Root directory
194                       -> Session a
195                       -> IO a
196 runSessionWithHandles serverIn serverOut serverHandler config caps rootDir session = do
197   -- We use this IORef to make exception non-fatal when the server is supposed to shutdown.
198
199   exitOk <- newIORef False
200   
201   absRootDir <- canonicalizePath rootDir
202
203   hSetBuffering serverIn  NoBuffering
204   hSetBuffering serverOut NoBuffering
205   -- This is required to make sure that we don’t get any
206   -- newline conversion or weird encoding issues.
207   hSetBinaryMode serverIn True
208   hSetBinaryMode serverOut True
209
210   reqMap <- newMVar newRequestMap
211   messageChan <- newChan
212   initRsp <- newEmptyMVar
213
214   mainThreadId <- myThreadId
215
216   let context = SessionContext serverIn absRootDir messageChan reqMap initRsp config caps
217       initState = SessionState (IdInt 0) mempty mempty 0 False Nothing
218       errorHandler ex = do x <- readIORef exitOk 
219                            unless x $ throwTo mainThreadId (ex :: SessionException)
220       launchServerHandler = forkIO $ catch (serverHandler serverOut context) errorHandler
221   (result, _) <- bracket
222                    launchServerHandler
223                    (\tid -> do runSession context initState sendExitMessage
224                                killThread tid
225                                atomicWriteIORef exitOk True)
226                    (const $ runSession context initState session)
227   return result
228
229 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
230 updateStateC = awaitForever $ \msg -> do
231   updateState msg
232   yield msg
233
234 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m) => FromServerMessage -> m ()
235 updateState (NotPublishDiagnostics n) = do
236   let List diags = n ^. params . diagnostics
237       doc = n ^. params . uri
238   modify (\s ->
239     let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s)
240       in s { curDiagnostics = newDiags })
241
242 updateState (ReqApplyWorkspaceEdit r) = do
243
244   allChangeParams <- case r ^. params . edit . documentChanges of
245     Just (List cs) -> do
246       mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
247       return $ map getParams cs
248     Nothing -> case r ^. params . edit . changes of
249       Just cs -> do
250         mapM_ checkIfNeedsOpened (HashMap.keys cs)
251         return $ concatMap (uncurry getChangeParams) (HashMap.toList cs)
252       Nothing -> error "No changes!"
253
254   modifyM $ \s -> do
255     newVFS <- liftIO $ changeFromServerVFS (vfs s) r
256     return $ s { vfs = newVFS }
257
258   let groupedParams = groupBy (\a b -> a ^. textDocument == b ^. textDocument) allChangeParams
259       mergedParams = map mergeParams groupedParams
260
261   -- TODO: Don't do this when replaying a session
262   forM_ mergedParams (sendMessage . NotificationMessage "2.0" TextDocumentDidChange)
263
264   -- Update VFS to new document versions
265   let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
266       latestVersions = map ((^. textDocument) . last) sortedVersions
267       bumpedVersions = map (version . _Just +~ 1) latestVersions
268
269   forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
270     modify $ \s ->
271       let oldVFS = vfs s
272           update (VirtualFile oldV t mf) = VirtualFile (fromMaybe oldV v) t mf
273           newVFS = Map.adjust update (toNormalizedUri uri) oldVFS
274       in s { vfs = newVFS }
275
276   where checkIfNeedsOpened uri = do
277           oldVFS <- vfs <$> get
278           ctx <- ask
279
280           -- if its not open, open it
281           unless (toNormalizedUri uri `Map.member` oldVFS) $ do
282             let fp = fromJust $ uriToFilePath uri
283             contents <- liftIO $ T.readFile fp
284             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
285                 msg = NotificationMessage "2.0" TextDocumentDidOpen (DidOpenTextDocumentParams item)
286             liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
287
288             modifyM $ \s -> do
289               newVFS <- liftIO $ openVFS (vfs s) msg
290               return $ s { vfs = newVFS }
291
292         getParams (TextDocumentEdit docId (List edits)) =
293           let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
294             in DidChangeTextDocumentParams docId (List changeEvents)
295
296         textDocumentVersions uri = map (VersionedTextDocumentIdentifier uri . Just) [0..]
297
298         textDocumentEdits uri edits = map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip (textDocumentVersions uri) edits
299
300         getChangeParams uri (List edits) = map getParams (textDocumentEdits uri (reverse edits))
301
302         mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
303         mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
304                               in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
305 updateState _ = return ()
306
307 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
308 sendMessage msg = do
309   h <- serverIn <$> ask
310   logMsg LogClient msg
311   liftIO $ B.hPut h (addHeader $ encode msg)
312
313 sendExitMessage :: (MonadIO m, HasReader SessionContext m) => m ()
314 sendExitMessage = sendMessage (NotificationMessage "2.0" Exit ExitParams)
315
316 -- | Execute a block f that will throw a 'Timeout' exception
317 -- after duration seconds. This will override the global timeout
318 -- for waiting for messages to arrive defined in 'SessionConfig'.
319 withTimeout :: Int -> Session a -> Session a
320 withTimeout duration f = do
321   chan <- asks messageChan
322   timeoutId <- curTimeoutId <$> get
323   modify $ \s -> s { overridingTimeout = True }
324   liftIO $ forkIO $ do
325     threadDelay (duration * 1000000)
326     writeChan chan (TimeoutMessage timeoutId)
327   res <- f
328   modify $ \s -> s { curTimeoutId = timeoutId + 1,
329                      overridingTimeout = False
330                    }
331   return res
332
333 data LogMsgType = LogServer | LogClient
334   deriving Eq
335
336 -- | Logs the message if the config specified it
337 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
338        => LogMsgType -> a -> m ()
339 logMsg t msg = do
340   shouldLog <- asks $ logMessages . config
341   shouldColor <- asks $ logColor . config
342   liftIO $ when shouldLog $ do
343     when shouldColor $ setSGR [SetColor Foreground Dull color]
344     putStrLn $ arrow ++ showPretty msg
345     when shouldColor $ setSGR [Reset]
346
347   where arrow
348           | t == LogServer  = "<-- "
349           | otherwise       = "--> "
350         color
351           | t == LogServer  = Magenta
352           | otherwise       = Cyan
353
354         showPretty = B.unpack . encodePretty
355