Properly terminate server handler thread on exceptions
[lsp-test.git] / src / Language / Haskell / LSP / Test / Session.hs
1 {-# LANGUAGE CPP               #-}
2 {-# LANGUAGE OverloadedStrings #-}
3 {-# LANGUAGE FlexibleInstances #-}
4 {-# LANGUAGE MultiParamTypeClasses #-}
5 {-# LANGUAGE FlexibleContexts #-}
6 {-# LANGUAGE RankNTypes #-}
7
8 module Language.Haskell.LSP.Test.Session
9   ( Session
10   , SessionConfig(..)
11   , defaultConfig
12   , SessionMessage(..)
13   , SessionContext(..)
14   , SessionState(..)
15   , runSessionWithHandles
16   , get
17   , put
18   , modify
19   , modifyM
20   , ask
21   , asks
22   , sendMessage
23   , updateState
24   , withTimeout
25   , logMsg
26   , LogMsgType(..)
27   )
28
29 where
30
31 import Control.Concurrent hiding (yield)
32 import Control.Exception
33 import Control.Lens hiding (List)
34 import Control.Monad
35 import Control.Monad.IO.Class
36 import Control.Monad.Except
37 #if __GLASGOW_HASKELL__ >= 806
38 import Control.Monad.Fail
39 #endif
40 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
41 import qualified Control.Monad.Trans.Reader as Reader (ask)
42 import Control.Monad.Trans.State (StateT, runStateT)
43 import qualified Control.Monad.Trans.State as State (get, put)
44 import qualified Data.ByteString.Lazy.Char8 as B
45 import Data.Aeson
46 import Data.Aeson.Encode.Pretty
47 import Data.Conduit as Conduit
48 import Data.Conduit.Parser as Parser
49 import Data.Default
50 import Data.Foldable
51 import Data.List
52 import qualified Data.Map as Map
53 import qualified Data.Text as T
54 import qualified Data.Text.IO as T
55 import qualified Data.HashMap.Strict as HashMap
56 import Data.Maybe
57 import Data.Function
58 import Language.Haskell.LSP.Messages
59 import Language.Haskell.LSP.Types.Capabilities
60 import Language.Haskell.LSP.Types
61 import Language.Haskell.LSP.Types.Lens hiding (error)
62 import Language.Haskell.LSP.VFS
63 import Language.Haskell.LSP.Test.Decoding
64 import Language.Haskell.LSP.Test.Exceptions
65 import System.Console.ANSI
66 import System.Directory
67 import System.IO
68
69 -- | A session representing one instance of launching and connecting to a server.
70 -- 
71 -- You can send and receive messages to the server within 'Session' via 'getMessage',
72 -- 'sendRequest' and 'sendNotification'.
73 --
74
75 type Session = ParserStateReader FromServerMessage SessionState SessionContext IO
76
77 #if __GLASGOW_HASKELL__ >= 806
78 instance MonadFail Session where
79   fail s = do
80     lastMsg <- fromJust . lastReceivedMessage <$> get
81     liftIO $ throw (UnexpectedMessage s lastMsg)
82 #endif
83
84 -- | Stuff you can configure for a 'Session'.
85 data SessionConfig = SessionConfig
86   { messageTimeout :: Int  -- ^ Maximum time to wait for a message in seconds, defaults to 60.
87   , logStdErr      :: Bool -- ^ Redirect the server's stderr to this stdout, defaults to False.
88   , logMessages    :: Bool -- ^ Trace the messages sent and received to stdout, defaults to False.
89   , logColor       :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
90   , lspConfig      :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
91   }
92
93 -- | The configuration used in 'Language.Haskell.LSP.Test.runSession'.
94 defaultConfig :: SessionConfig
95 defaultConfig = SessionConfig 60 False False True Nothing
96
97 instance Default SessionConfig where
98   def = defaultConfig
99
100 data SessionMessage = ServerMessage FromServerMessage
101                     | TimeoutMessage Int
102   deriving Show
103
104 data SessionContext = SessionContext
105   {
106     serverIn :: Handle
107   , rootDir :: FilePath
108   , messageChan :: Chan SessionMessage
109   , requestMap :: MVar RequestMap
110   , initRsp :: MVar InitializeResponse
111   , config :: SessionConfig
112   , sessionCapabilities :: ClientCapabilities
113   }
114
115 class Monad m => HasReader r m where
116   ask :: m r
117   asks :: (r -> b) -> m b
118   asks f = f <$> ask
119
120 instance Monad m => HasReader r (ParserStateReader a s r m) where
121   ask = lift $ lift Reader.ask
122
123 instance Monad m => HasReader SessionContext (ConduitM a b (StateT s (ReaderT SessionContext m))) where
124   ask = lift $ lift Reader.ask
125
126 data SessionState = SessionState
127   {
128     curReqId :: LspId
129   , vfs :: VFS
130   , curDiagnostics :: Map.Map Uri [Diagnostic]
131   , curTimeoutId :: Int
132   , overridingTimeout :: Bool
133   -- ^ The last received message from the server.
134   -- Used for providing exception information
135   , lastReceivedMessage :: Maybe FromServerMessage
136   }
137
138 class Monad m => HasState s m where
139   get :: m s
140
141   put :: s -> m ()
142
143   modify :: (s -> s) -> m ()
144   modify f = get >>= put . f
145
146   modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
147   modifyM f = get >>= f >>= put
148
149 instance Monad m => HasState s (ParserStateReader a s r m) where
150   get = lift State.get
151   put = lift . State.put
152
153 instance Monad m => HasState SessionState (ConduitM a b (StateT SessionState m))
154  where
155   get = lift State.get
156   put = lift . State.put
157
158 type ParserStateReader a s r m = ConduitParser a (StateT s (ReaderT r m))
159
160 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
161 runSession context state session = runReaderT (runStateT conduit state) context
162   where
163     conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
164
165     handler (Unexpected "ConduitParser.empty") = do
166       lastMsg <- fromJust . lastReceivedMessage <$> get
167       name <- getParserName
168       liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
169
170     handler e = throw e
171
172     chanSource = do
173       msg <- liftIO $ readChan (messageChan context)
174       yield msg
175       chanSource
176
177     watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
178     watchdog = Conduit.awaitForever $ \msg -> do
179       curId <- curTimeoutId <$> get
180       case msg of
181         ServerMessage sMsg -> yield sMsg
182         TimeoutMessage tId -> when (curId == tId) $ throw Timeout
183
184 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
185 -- It also does not automatically send initialize and exit messages.
186 runSessionWithHandles :: Handle -- ^ Server in
187                       -> Handle -- ^ Server out
188                       -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
189                       -> SessionConfig
190                       -> ClientCapabilities
191                       -> FilePath -- ^ Root directory
192                       -> Session a
193                       -> IO a
194 runSessionWithHandles serverIn serverOut serverHandler config caps rootDir session = do
195   absRootDir <- canonicalizePath rootDir
196
197   hSetBuffering serverIn  NoBuffering
198   hSetBuffering serverOut NoBuffering
199
200   reqMap <- newMVar newRequestMap
201   messageChan <- newChan
202   initRsp <- newEmptyMVar
203
204   let context = SessionContext serverIn absRootDir messageChan reqMap initRsp config caps
205       initState = SessionState (IdInt 0) mempty mempty 0 False Nothing
206       launchServerHandler = forkIO $ void $ serverHandler serverOut context
207   (result, _) <- bracket launchServerHandler killThread $
208     const $ runSession context initState session
209
210   return result
211
212 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
213 updateStateC = awaitForever $ \msg -> do
214   updateState msg
215   yield msg
216
217 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m) => FromServerMessage -> m ()
218 updateState (NotPublishDiagnostics n) = do
219   let List diags = n ^. params . diagnostics
220       doc = n ^. params . uri
221   modify (\s ->
222     let newDiags = Map.insert doc diags (curDiagnostics s)
223       in s { curDiagnostics = newDiags })
224
225 updateState (ReqApplyWorkspaceEdit r) = do
226
227   allChangeParams <- case r ^. params . edit . documentChanges of
228     Just (List cs) -> do
229       mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
230       return $ map getParams cs
231     Nothing -> case r ^. params . edit . changes of
232       Just cs -> do
233         mapM_ checkIfNeedsOpened (HashMap.keys cs)
234         return $ concatMap (uncurry getChangeParams) (HashMap.toList cs)
235       Nothing -> error "No changes!"
236
237   modifyM $ \s -> do
238     newVFS <- liftIO $ changeFromServerVFS (vfs s) r
239     return $ s { vfs = newVFS }
240
241   let groupedParams = groupBy (\a b -> (a ^. textDocument == b ^. textDocument)) allChangeParams
242       mergedParams = map mergeParams groupedParams
243
244   -- TODO: Don't do this when replaying a session
245   forM_ mergedParams (sendMessage . NotificationMessage "2.0" TextDocumentDidChange)
246
247   -- Update VFS to new document versions
248   let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
249       latestVersions = map ((^. textDocument) . last) sortedVersions
250       bumpedVersions = map (version . _Just +~ 1) latestVersions
251
252   forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
253     modify $ \s ->
254       let oldVFS = vfs s
255           update (VirtualFile oldV t) = VirtualFile (fromMaybe oldV v) t
256           newVFS = Map.adjust update uri oldVFS
257       in s { vfs = newVFS }
258
259   where checkIfNeedsOpened uri = do
260           oldVFS <- vfs <$> get
261           ctx <- ask
262
263           -- if its not open, open it
264           unless (uri `Map.member` oldVFS) $ do
265             let fp = fromJust $ uriToFilePath uri
266             contents <- liftIO $ T.readFile fp
267             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
268                 msg = NotificationMessage "2.0" TextDocumentDidOpen (DidOpenTextDocumentParams item)
269             liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
270
271             modifyM $ \s -> do
272               newVFS <- liftIO $ openVFS (vfs s) msg
273               return $ s { vfs = newVFS }
274
275         getParams (TextDocumentEdit docId (List edits)) =
276           let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
277             in DidChangeTextDocumentParams docId (List changeEvents)
278
279         textDocumentVersions uri = map (VersionedTextDocumentIdentifier uri . Just) [0..]
280
281         textDocumentEdits uri edits = map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip (textDocumentVersions uri) edits
282
283         getChangeParams uri (List edits) = map getParams (textDocumentEdits uri (reverse edits))
284
285         mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
286         mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
287                               in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
288 updateState _ = return ()
289
290 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
291 sendMessage msg = do
292   h <- serverIn <$> ask
293   logMsg LogClient msg
294   liftIO $ B.hPut h (addHeader $ encode msg)
295
296 -- | Execute a block f that will throw a 'TimeoutException'
297 -- after duration seconds. This will override the global timeout
298 -- for waiting for messages to arrive defined in 'SessionConfig'.
299 withTimeout :: Int -> Session a -> Session a
300 withTimeout duration f = do
301   chan <- asks messageChan
302   timeoutId <- curTimeoutId <$> get
303   modify $ \s -> s { overridingTimeout = True }
304   liftIO $ forkIO $ do
305     threadDelay (duration * 1000000)
306     writeChan chan (TimeoutMessage timeoutId)
307   res <- f
308   modify $ \s -> s { curTimeoutId = timeoutId + 1,
309                      overridingTimeout = False
310                    }
311   return res
312
313 data LogMsgType = LogServer | LogClient
314   deriving Eq
315
316 -- | Logs the message if the config specified it
317 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
318        => LogMsgType -> a -> m ()
319 logMsg t msg = do
320   shouldLog <- asks $ logMessages . config
321   shouldColor <- asks $ logColor . config
322   liftIO $ when shouldLog $ do
323     when shouldColor $ setSGR [SetColor Foreground Dull color]
324     putStrLn $ arrow ++ showPretty msg
325     when shouldColor $ setSGR [Reset]
326
327   where arrow
328           | t == LogServer  = "<-- "
329           | otherwise       = "--> "
330         color
331           | t == LogServer  = Magenta
332           | otherwise       = Cyan
333
334         showPretty = B.unpack . encodePretty