Shutdown the server before kill its thread
[lsp-test.git] / src / Language / Haskell / LSP / Test / Session.hs
1 {-# LANGUAGE CPP               #-}
2 {-# LANGUAGE OverloadedStrings #-}
3 {-# LANGUAGE FlexibleInstances #-}
4 {-# LANGUAGE MultiParamTypeClasses #-}
5 {-# LANGUAGE FlexibleContexts #-}
6 {-# LANGUAGE RankNTypes #-}
7
8 module Language.Haskell.LSP.Test.Session
9   ( Session
10   , SessionConfig(..)
11   , defaultConfig
12   , SessionMessage(..)
13   , SessionContext(..)
14   , SessionState(..)
15   , runSessionWithHandles
16   , get
17   , put
18   , modify
19   , modifyM
20   , ask
21   , asks
22   , sendMessage
23   , updateState
24   , withTimeout
25   , logMsg
26   , LogMsgType(..)
27   )
28
29 where
30
31 import Control.Concurrent hiding (yield)
32 import Control.Exception
33 import Control.Lens hiding (List)
34 import Control.Monad
35 import Control.Monad.IO.Class
36 import Control.Monad.Except
37 #if __GLASGOW_HASKELL__ >= 806
38 import Control.Monad.Fail
39 #endif
40 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
41 import qualified Control.Monad.Trans.Reader as Reader (ask)
42 import Control.Monad.Trans.State (StateT, runStateT)
43 import qualified Control.Monad.Trans.State as State (get, put)
44 import qualified Data.ByteString.Lazy.Char8 as B
45 import Data.Aeson
46 import Data.Aeson.Encode.Pretty
47 import Data.Conduit as Conduit
48 import Data.Conduit.Parser as Parser
49 import Data.Default
50 import Data.Foldable
51 import Data.List
52 import qualified Data.Map as Map
53 import qualified Data.Text as T
54 import qualified Data.Text.IO as T
55 import qualified Data.HashMap.Strict as HashMap
56 import Data.Maybe
57 import Data.Function
58 import Language.Haskell.LSP.Messages
59 import Language.Haskell.LSP.Types.Capabilities
60 import Language.Haskell.LSP.Types
61 import Language.Haskell.LSP.Types.Lens hiding (error)
62 import Language.Haskell.LSP.VFS
63 import Language.Haskell.LSP.Test.Decoding
64 import Language.Haskell.LSP.Test.Exceptions
65 import System.Console.ANSI
66 import System.Directory
67 import System.IO
68
69 -- | A session representing one instance of launching and connecting to a server.
70 --
71 -- You can send and receive messages to the server within 'Session' via
72 -- 'Language.Haskell.LSP.Test.message',
73 -- 'Language.Haskell.LSP.Test.sendRequest' and
74 -- 'Language.Haskell.LSP.Test.sendNotification'.
75
76 type Session = ParserStateReader FromServerMessage SessionState SessionContext IO
77
78 #if __GLASGOW_HASKELL__ >= 806
79 instance MonadFail Session where
80   fail s = do
81     lastMsg <- fromJust . lastReceivedMessage <$> get
82     liftIO $ throw (UnexpectedMessage s lastMsg)
83 #endif
84
85 -- | Stuff you can configure for a 'Session'.
86 data SessionConfig = SessionConfig
87   { messageTimeout :: Int  -- ^ Maximum time to wait for a message in seconds, defaults to 60.
88   , logStdErr      :: Bool -- ^ Redirect the server's stderr to this stdout, defaults to False.
89   , logMessages    :: Bool -- ^ Trace the messages sent and received to stdout, defaults to False.
90   , logColor       :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
91   , lspConfig      :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
92   }
93
94 -- | The configuration used in 'Language.Haskell.LSP.Test.runSession'.
95 defaultConfig :: SessionConfig
96 defaultConfig = SessionConfig 60 False False True Nothing
97
98 instance Default SessionConfig where
99   def = defaultConfig
100
101 data SessionMessage = ServerMessage FromServerMessage
102                     | TimeoutMessage Int
103   deriving Show
104
105 data SessionContext = SessionContext
106   {
107     serverIn :: Handle
108   , rootDir :: FilePath
109   , messageChan :: Chan SessionMessage
110   , requestMap :: MVar RequestMap
111   , initRsp :: MVar InitializeResponse
112   , config :: SessionConfig
113   , sessionCapabilities :: ClientCapabilities
114   }
115
116 class Monad m => HasReader r m where
117   ask :: m r
118   asks :: (r -> b) -> m b
119   asks f = f <$> ask
120
121 instance Monad m => HasReader r (ParserStateReader a s r m) where
122   ask = lift $ lift Reader.ask
123
124 instance Monad m => HasReader SessionContext (ConduitM a b (StateT s (ReaderT SessionContext m))) where
125   ask = lift $ lift Reader.ask
126
127 data SessionState = SessionState
128   {
129     curReqId :: LspId
130   , vfs :: VFS
131   , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
132   , curTimeoutId :: Int
133   , overridingTimeout :: Bool
134   -- ^ The last received message from the server.
135   -- Used for providing exception information
136   , lastReceivedMessage :: Maybe FromServerMessage
137   }
138
139 class Monad m => HasState s m where
140   get :: m s
141
142   put :: s -> m ()
143
144   modify :: (s -> s) -> m ()
145   modify f = get >>= put . f
146
147   modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
148   modifyM f = get >>= f >>= put
149
150 instance Monad m => HasState s (ParserStateReader a s r m) where
151   get = lift State.get
152   put = lift . State.put
153
154 instance Monad m => HasState SessionState (ConduitM a b (StateT SessionState m))
155  where
156   get = lift State.get
157   put = lift . State.put
158
159 type ParserStateReader a s r m = ConduitParser a (StateT s (ReaderT r m))
160
161 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
162 runSession context state session = runReaderT (runStateT conduit state) context
163   where
164     conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
165
166     handler (Unexpected "ConduitParser.empty") = do
167       lastMsg <- fromJust . lastReceivedMessage <$> get
168       name <- getParserName
169       liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
170
171     handler e = throw e
172
173     chanSource = do
174       msg <- liftIO $ readChan (messageChan context)
175       yield msg
176       chanSource
177
178     watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
179     watchdog = Conduit.awaitForever $ \msg -> do
180       curId <- curTimeoutId <$> get
181       case msg of
182         ServerMessage sMsg -> yield sMsg
183         TimeoutMessage tId -> when (curId == tId) $ throw Timeout
184
185 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
186 -- It also does not automatically send initialize and exit messages.
187 runSessionWithHandles :: Handle -- ^ Server in
188                       -> Handle -- ^ Server out
189                       -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
190                       -> SessionConfig
191                       -> ClientCapabilities
192                       -> FilePath -- ^ Root directory
193                       -> Session a
194                       -> IO a
195 runSessionWithHandles serverIn serverOut serverHandler config caps rootDir session = do
196   absRootDir <- canonicalizePath rootDir
197
198   hSetBuffering serverIn  NoBuffering
199   hSetBuffering serverOut NoBuffering
200   -- This is required to make sure that we don’t get any
201   -- newline conversion or weird encoding issues.
202   hSetBinaryMode serverIn True
203   hSetBinaryMode serverOut True
204
205   reqMap <- newMVar newRequestMap
206   messageChan <- newChan
207   initRsp <- newEmptyMVar
208
209   mainThreadId <- myThreadId
210
211   let context = SessionContext serverIn absRootDir messageChan reqMap initRsp config caps
212       initState = SessionState (IdInt 0) mempty mempty 0 False Nothing
213       launchServerHandler = forkIO $ catch (serverHandler serverOut context)
214                                            (throwTo mainThreadId :: SessionException -> IO())
215   (result, _) <- bracket
216                    launchServerHandler
217                    (\tid -> do runSession context initState sendExitMessage
218                                killThread tid)
219                    (const $ runSession context initState session)
220   return result
221
222 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
223 updateStateC = awaitForever $ \msg -> do
224   updateState msg
225   yield msg
226
227 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m) => FromServerMessage -> m ()
228 updateState (NotPublishDiagnostics n) = do
229   let List diags = n ^. params . diagnostics
230       doc = n ^. params . uri
231   modify (\s ->
232     let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s)
233       in s { curDiagnostics = newDiags })
234
235 updateState (ReqApplyWorkspaceEdit r) = do
236
237   allChangeParams <- case r ^. params . edit . documentChanges of
238     Just (List cs) -> do
239       mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
240       return $ map getParams cs
241     Nothing -> case r ^. params . edit . changes of
242       Just cs -> do
243         mapM_ checkIfNeedsOpened (HashMap.keys cs)
244         return $ concatMap (uncurry getChangeParams) (HashMap.toList cs)
245       Nothing -> error "No changes!"
246
247   modifyM $ \s -> do
248     newVFS <- liftIO $ changeFromServerVFS (vfs s) r
249     return $ s { vfs = newVFS }
250
251   let groupedParams = groupBy (\a b -> a ^. textDocument == b ^. textDocument) allChangeParams
252       mergedParams = map mergeParams groupedParams
253
254   -- TODO: Don't do this when replaying a session
255   forM_ mergedParams (sendMessage . NotificationMessage "2.0" TextDocumentDidChange)
256
257   -- Update VFS to new document versions
258   let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
259       latestVersions = map ((^. textDocument) . last) sortedVersions
260       bumpedVersions = map (version . _Just +~ 1) latestVersions
261
262   forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
263     modify $ \s ->
264       let oldVFS = vfs s
265           update (VirtualFile oldV t mf) = VirtualFile (fromMaybe oldV v) t mf
266           newVFS = Map.adjust update (toNormalizedUri uri) oldVFS
267       in s { vfs = newVFS }
268
269   where checkIfNeedsOpened uri = do
270           oldVFS <- vfs <$> get
271           ctx <- ask
272
273           -- if its not open, open it
274           unless (toNormalizedUri uri `Map.member` oldVFS) $ do
275             let fp = fromJust $ uriToFilePath uri
276             contents <- liftIO $ T.readFile fp
277             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
278                 msg = NotificationMessage "2.0" TextDocumentDidOpen (DidOpenTextDocumentParams item)
279             liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
280
281             modifyM $ \s -> do
282               newVFS <- liftIO $ openVFS (vfs s) msg
283               return $ s { vfs = newVFS }
284
285         getParams (TextDocumentEdit docId (List edits)) =
286           let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
287             in DidChangeTextDocumentParams docId (List changeEvents)
288
289         textDocumentVersions uri = map (VersionedTextDocumentIdentifier uri . Just) [0..]
290
291         textDocumentEdits uri edits = map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip (textDocumentVersions uri) edits
292
293         getChangeParams uri (List edits) = map getParams (textDocumentEdits uri (reverse edits))
294
295         mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
296         mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
297                               in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
298 updateState _ = return ()
299
300 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
301 sendMessage msg = do
302   h <- serverIn <$> ask
303   logMsg LogClient msg
304   liftIO $ B.hPut h (addHeader $ encode msg)
305
306 sendExitMessage :: (MonadIO m, HasReader SessionContext m) => m ()
307 sendExitMessage = sendMessage (NotificationMessage "2.0" Exit ExitParams)
308
309 -- | Execute a block f that will throw a 'Timeout' exception
310 -- after duration seconds. This will override the global timeout
311 -- for waiting for messages to arrive defined in 'SessionConfig'.
312 withTimeout :: Int -> Session a -> Session a
313 withTimeout duration f = do
314   chan <- asks messageChan
315   timeoutId <- curTimeoutId <$> get
316   modify $ \s -> s { overridingTimeout = True }
317   liftIO $ forkIO $ do
318     threadDelay (duration * 1000000)
319     writeChan chan (TimeoutMessage timeoutId)
320   res <- f
321   modify $ \s -> s { curTimeoutId = timeoutId + 1,
322                      overridingTimeout = False
323                    }
324   return res
325
326 data LogMsgType = LogServer | LogClient
327   deriving Eq
328
329 -- | Logs the message if the config specified it
330 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
331        => LogMsgType -> a -> m ()
332 logMsg t msg = do
333   shouldLog <- asks $ logMessages . config
334   shouldColor <- asks $ logColor . config
335   liftIO $ when shouldLog $ do
336     when shouldColor $ setSGR [SetColor Foreground Dull color]
337     putStrLn $ arrow ++ showPretty msg
338     when shouldColor $ setSGR [Reset]
339
340   where arrow
341           | t == LogServer  = "<-- "
342           | otherwise       = "--> "
343         color
344           | t == LogServer  = Magenta
345           | otherwise       = Cyan
346
347         showPretty = B.unpack . encodePretty
348