Add a finally and timeout to ensure the call to killThread
[lsp-test.git] / src / Language / Haskell / LSP / Test / Session.hs
1 {-# LANGUAGE CPP               #-}
2 {-# LANGUAGE OverloadedStrings #-}
3 {-# LANGUAGE FlexibleInstances #-}
4 {-# LANGUAGE MultiParamTypeClasses #-}
5 {-# LANGUAGE FlexibleContexts #-}
6 {-# LANGUAGE RankNTypes #-}
7
8 module Language.Haskell.LSP.Test.Session
9   ( Session
10   , SessionConfig(..)
11   , defaultConfig
12   , SessionMessage(..)
13   , SessionContext(..)
14   , SessionState(..)
15   , runSessionWithHandles
16   , get
17   , put
18   , modify
19   , modifyM
20   , ask
21   , asks
22   , sendMessage
23   , updateState
24   , withTimeout
25   , logMsg
26   , LogMsgType(..)
27   )
28
29 where
30
31 import Control.Concurrent hiding (yield)
32 import Control.Exception
33 import Control.Lens hiding (List)
34 import Control.Monad
35 import Control.Monad.IO.Class
36 import Control.Monad.Except
37 #if __GLASGOW_HASKELL__ >= 806
38 import Control.Monad.Fail
39 #endif
40 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
41 import qualified Control.Monad.Trans.Reader as Reader (ask)
42 import Control.Monad.Trans.State (StateT, runStateT)
43 import qualified Control.Monad.Trans.State as State (get, put)
44 import qualified Data.ByteString.Lazy.Char8 as B
45 import Data.Aeson
46 import Data.Aeson.Encode.Pretty
47 import Data.Conduit as Conduit
48 import Data.Conduit.Parser as Parser
49 import Data.Default
50 import Data.Foldable
51 import Data.List
52 import qualified Data.Map as Map
53 import qualified Data.Text as T
54 import qualified Data.Text.IO as T
55 import qualified Data.HashMap.Strict as HashMap
56 import Data.Maybe
57 import Data.Function
58 import Language.Haskell.LSP.Messages
59 import Language.Haskell.LSP.Types.Capabilities
60 import Language.Haskell.LSP.Types
61 import Language.Haskell.LSP.Types.Lens hiding (error)
62 import Language.Haskell.LSP.VFS
63 import Language.Haskell.LSP.Test.Decoding
64 import Language.Haskell.LSP.Test.Exceptions
65 import System.Console.ANSI
66 import System.Directory
67 import System.IO
68 import System.Timeout
69
70 -- | A session representing one instance of launching and connecting to a server.
71 --
72 -- You can send and receive messages to the server within 'Session' via
73 -- 'Language.Haskell.LSP.Test.message',
74 -- 'Language.Haskell.LSP.Test.sendRequest' and
75 -- 'Language.Haskell.LSP.Test.sendNotification'.
76
77 type Session = ParserStateReader FromServerMessage SessionState SessionContext IO
78
79 #if __GLASGOW_HASKELL__ >= 806
80 instance MonadFail Session where
81   fail s = do
82     lastMsg <- fromJust . lastReceivedMessage <$> get
83     liftIO $ throw (UnexpectedMessage s lastMsg)
84 #endif
85
86 -- | Stuff you can configure for a 'Session'.
87 data SessionConfig = SessionConfig
88   { messageTimeout :: Int  -- ^ Maximum time to wait for a message in seconds, defaults to 60.
89   , logStdErr      :: Bool -- ^ Redirect the server's stderr to this stdout, defaults to False.
90   , logMessages    :: Bool -- ^ Trace the messages sent and received to stdout, defaults to False.
91   , logColor       :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
92   , lspConfig      :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
93   }
94
95 -- | The configuration used in 'Language.Haskell.LSP.Test.runSession'.
96 defaultConfig :: SessionConfig
97 defaultConfig = SessionConfig 60 False False True Nothing
98
99 instance Default SessionConfig where
100   def = defaultConfig
101
102 data SessionMessage = ServerMessage FromServerMessage
103                     | TimeoutMessage Int
104   deriving Show
105
106 data SessionContext = SessionContext
107   {
108     serverIn :: Handle
109   , rootDir :: FilePath
110   , messageChan :: Chan SessionMessage
111   , requestMap :: MVar RequestMap
112   , initRsp :: MVar InitializeResponse
113   , config :: SessionConfig
114   , sessionCapabilities :: ClientCapabilities
115   }
116
117 class Monad m => HasReader r m where
118   ask :: m r
119   asks :: (r -> b) -> m b
120   asks f = f <$> ask
121
122 instance Monad m => HasReader r (ParserStateReader a s r m) where
123   ask = lift $ lift Reader.ask
124
125 instance Monad m => HasReader SessionContext (ConduitM a b (StateT s (ReaderT SessionContext m))) where
126   ask = lift $ lift Reader.ask
127
128 data SessionState = SessionState
129   {
130     curReqId :: LspId
131   , vfs :: VFS
132   , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
133   , curTimeoutId :: Int
134   , overridingTimeout :: Bool
135   -- ^ The last received message from the server.
136   -- Used for providing exception information
137   , lastReceivedMessage :: Maybe FromServerMessage
138   }
139
140 class Monad m => HasState s m where
141   get :: m s
142
143   put :: s -> m ()
144
145   modify :: (s -> s) -> m ()
146   modify f = get >>= put . f
147
148   modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
149   modifyM f = get >>= f >>= put
150
151 instance Monad m => HasState s (ParserStateReader a s r m) where
152   get = lift State.get
153   put = lift . State.put
154
155 instance Monad m => HasState SessionState (ConduitM a b (StateT SessionState m))
156  where
157   get = lift State.get
158   put = lift . State.put
159
160 type ParserStateReader a s r m = ConduitParser a (StateT s (ReaderT r m))
161
162 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
163 runSession context state session = runReaderT (runStateT conduit state) context
164   where
165     conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
166
167     handler (Unexpected "ConduitParser.empty") = do
168       lastMsg <- fromJust . lastReceivedMessage <$> get
169       name <- getParserName
170       liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
171
172     handler e = throw e
173
174     chanSource = do
175       msg <- liftIO $ readChan (messageChan context)
176       yield msg
177       chanSource
178
179     watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
180     watchdog = Conduit.awaitForever $ \msg -> do
181       curId <- curTimeoutId <$> get
182       case msg of
183         ServerMessage sMsg -> yield sMsg
184         TimeoutMessage tId -> when (curId == tId) $ throw Timeout
185
186 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
187 -- It also does not automatically send initialize and exit messages.
188 runSessionWithHandles :: Handle -- ^ Server in
189                       -> Handle -- ^ Server out
190                       -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
191                       -> SessionConfig
192                       -> ClientCapabilities
193                       -> FilePath -- ^ Root directory
194                       -> Session () -- ^ To exit Server
195                       -> Session a
196                       -> IO a
197 runSessionWithHandles serverIn serverOut serverHandler config caps rootDir exitServer session = do
198   
199   absRootDir <- canonicalizePath rootDir
200
201   hSetBuffering serverIn  NoBuffering
202   hSetBuffering serverOut NoBuffering
203   -- This is required to make sure that we don’t get any
204   -- newline conversion or weird encoding issues.
205   hSetBinaryMode serverIn True
206   hSetBinaryMode serverOut True
207
208   reqMap <- newMVar newRequestMap
209   messageChan <- newChan
210   initRsp <- newEmptyMVar
211
212   mainThreadId <- myThreadId
213
214   let context = SessionContext serverIn absRootDir messageChan reqMap initRsp config caps
215       initState = SessionState (IdInt 0) mempty mempty 0 False Nothing
216       runSession' = runSession context initState
217       
218       errorHandler = throwTo mainThreadId :: SessionException -> IO()
219       serverLauncher = forkIO $ catch (serverHandler serverOut context) errorHandler
220       serverFinalizer tid = finally (timeout 60000000 (runSession' exitServer))
221                                     (killThread tid)
222       
223   (result, _) <- bracket serverLauncher serverFinalizer (const $ runSession' session)
224   return result
225
226 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
227 updateStateC = awaitForever $ \msg -> do
228   updateState msg
229   yield msg
230
231 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m) => FromServerMessage -> m ()
232 updateState (NotPublishDiagnostics n) = do
233   let List diags = n ^. params . diagnostics
234       doc = n ^. params . uri
235   modify (\s ->
236     let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s)
237       in s { curDiagnostics = newDiags })
238
239 updateState (ReqApplyWorkspaceEdit r) = do
240
241   allChangeParams <- case r ^. params . edit . documentChanges of
242     Just (List cs) -> do
243       mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
244       return $ map getParams cs
245     Nothing -> case r ^. params . edit . changes of
246       Just cs -> do
247         mapM_ checkIfNeedsOpened (HashMap.keys cs)
248         return $ concatMap (uncurry getChangeParams) (HashMap.toList cs)
249       Nothing -> error "No changes!"
250
251   modifyM $ \s -> do
252     newVFS <- liftIO $ changeFromServerVFS (vfs s) r
253     return $ s { vfs = newVFS }
254
255   let groupedParams = groupBy (\a b -> a ^. textDocument == b ^. textDocument) allChangeParams
256       mergedParams = map mergeParams groupedParams
257
258   -- TODO: Don't do this when replaying a session
259   forM_ mergedParams (sendMessage . NotificationMessage "2.0" TextDocumentDidChange)
260
261   -- Update VFS to new document versions
262   let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
263       latestVersions = map ((^. textDocument) . last) sortedVersions
264       bumpedVersions = map (version . _Just +~ 1) latestVersions
265
266   forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
267     modify $ \s ->
268       let oldVFS = vfs s
269           update (VirtualFile oldV t mf) = VirtualFile (fromMaybe oldV v) t mf
270           newVFS = Map.adjust update (toNormalizedUri uri) oldVFS
271       in s { vfs = newVFS }
272
273   where checkIfNeedsOpened uri = do
274           oldVFS <- vfs <$> get
275           ctx <- ask
276
277           -- if its not open, open it
278           unless (toNormalizedUri uri `Map.member` oldVFS) $ do
279             let fp = fromJust $ uriToFilePath uri
280             contents <- liftIO $ T.readFile fp
281             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
282                 msg = NotificationMessage "2.0" TextDocumentDidOpen (DidOpenTextDocumentParams item)
283             liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
284
285             modifyM $ \s -> do
286               newVFS <- liftIO $ openVFS (vfs s) msg
287               return $ s { vfs = newVFS }
288
289         getParams (TextDocumentEdit docId (List edits)) =
290           let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
291             in DidChangeTextDocumentParams docId (List changeEvents)
292
293         textDocumentVersions uri = map (VersionedTextDocumentIdentifier uri . Just) [0..]
294
295         textDocumentEdits uri edits = map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip (textDocumentVersions uri) edits
296
297         getChangeParams uri (List edits) = map getParams (textDocumentEdits uri (reverse edits))
298
299         mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
300         mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
301                               in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
302 updateState _ = return ()
303
304 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
305 sendMessage msg = do
306   h <- serverIn <$> ask
307   logMsg LogClient msg
308   liftIO $ B.hPut h (addHeader $ encode msg)
309
310 -- | Execute a block f that will throw a 'Timeout' exception
311 -- after duration seconds. This will override the global timeout
312 -- for waiting for messages to arrive defined in 'SessionConfig'.
313 withTimeout :: Int -> Session a -> Session a
314 withTimeout duration f = do
315   chan <- asks messageChan
316   timeoutId <- curTimeoutId <$> get
317   modify $ \s -> s { overridingTimeout = True }
318   liftIO $ forkIO $ do
319     threadDelay (duration * 1000000)
320     writeChan chan (TimeoutMessage timeoutId)
321   res <- f
322   modify $ \s -> s { curTimeoutId = timeoutId + 1,
323                      overridingTimeout = False
324                    }
325   return res
326
327 data LogMsgType = LogServer | LogClient
328   deriving Eq
329
330 -- | Logs the message if the config specified it
331 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
332        => LogMsgType -> a -> m ()
333 logMsg t msg = do
334   shouldLog <- asks $ logMessages . config
335   shouldColor <- asks $ logColor . config
336   liftIO $ when shouldLog $ do
337     when shouldColor $ setSGR [SetColor Foreground Dull color]
338     putStrLn $ arrow ++ showPretty msg
339     when shouldColor $ setSGR [Reset]
340
341   where arrow
342           | t == LogServer  = "<-- "
343           | otherwise       = "--> "
344         color
345           | t == LogServer  = Magenta
346           | otherwise       = Cyan
347
348         showPretty = B.unpack . encodePretty
349