Exit the server and its listener properly
[lsp-test.git] / src / Language / Haskell / LSP / Test / Session.hs
1 {-# LANGUAGE CPP               #-}
2 {-# LANGUAGE OverloadedStrings #-}
3 {-# LANGUAGE FlexibleInstances #-}
4 {-# LANGUAGE MultiParamTypeClasses #-}
5 {-# LANGUAGE FlexibleContexts #-}
6 {-# LANGUAGE RankNTypes #-}
7
8 module Language.Haskell.LSP.Test.Session
9   ( Session
10   , SessionConfig(..)
11   , defaultConfig
12   , SessionMessage(..)
13   , SessionContext(..)
14   , SessionState(..)
15   , runSessionWithHandles
16   , get
17   , put
18   , modify
19   , modifyM
20   , ask
21   , asks
22   , sendMessage
23   , updateState
24   , withTimeout
25   , logMsg
26   , LogMsgType(..)
27   )
28
29 where
30
31 import Control.Concurrent hiding (yield)
32 import Control.Exception
33 import Control.Lens hiding (List)
34 import Control.Monad
35 import Control.Monad.IO.Class
36 import Control.Monad.Except
37 #if __GLASGOW_HASKELL__ >= 806
38 import Control.Monad.Fail
39 #endif
40 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
41 import qualified Control.Monad.Trans.Reader as Reader (ask)
42 import Control.Monad.Trans.State (StateT, runStateT)
43 import qualified Control.Monad.Trans.State as State (get, put)
44 import qualified Data.ByteString.Lazy.Char8 as B
45 import Data.Aeson
46 import Data.Aeson.Encode.Pretty
47 import Data.Conduit as Conduit
48 import Data.Conduit.Parser as Parser
49 import Data.Default
50 import Data.Foldable
51 import Data.List
52 import qualified Data.Map as Map
53 import qualified Data.Text as T
54 import qualified Data.Text.IO as T
55 import qualified Data.HashMap.Strict as HashMap
56 import Data.Maybe
57 import Data.Function
58 import Language.Haskell.LSP.Messages
59 import Language.Haskell.LSP.Types.Capabilities
60 import Language.Haskell.LSP.Types
61 import Language.Haskell.LSP.Types.Lens hiding (error)
62 import Language.Haskell.LSP.VFS
63 import Language.Haskell.LSP.Test.Decoding
64 import Language.Haskell.LSP.Test.Exceptions
65 import System.Console.ANSI
66 import System.Directory
67 import System.IO
68
69 -- | A session representing one instance of launching and connecting to a server.
70 --
71 -- You can send and receive messages to the server within 'Session' via
72 -- 'Language.Haskell.LSP.Test.message',
73 -- 'Language.Haskell.LSP.Test.sendRequest' and
74 -- 'Language.Haskell.LSP.Test.sendNotification'.
75
76 type Session = ParserStateReader FromServerMessage SessionState SessionContext IO
77
78 #if __GLASGOW_HASKELL__ >= 806
79 instance MonadFail Session where
80   fail s = do
81     lastMsg <- fromJust . lastReceivedMessage <$> get
82     liftIO $ throw (UnexpectedMessage s lastMsg)
83 #endif
84
85 -- | Stuff you can configure for a 'Session'.
86 data SessionConfig = SessionConfig
87   { messageTimeout :: Int  -- ^ Maximum time to wait for a message in seconds, defaults to 60.
88   , logStdErr      :: Bool -- ^ Redirect the server's stderr to this stdout, defaults to False.
89   , logMessages    :: Bool -- ^ Trace the messages sent and received to stdout, defaults to False.
90   , logColor       :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
91   , lspConfig      :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
92   }
93
94 -- | The configuration used in 'Language.Haskell.LSP.Test.runSession'.
95 defaultConfig :: SessionConfig
96 defaultConfig = SessionConfig 60 False False True Nothing
97
98 instance Default SessionConfig where
99   def = defaultConfig
100
101 data SessionMessage = ServerMessage FromServerMessage
102                     | TimeoutMessage Int
103   deriving Show
104
105 data SessionContext = SessionContext
106   {
107     serverIn :: Handle
108   , rootDir :: FilePath
109   , messageChan :: Chan SessionMessage
110   , requestMap :: MVar RequestMap
111   , initRsp :: MVar InitializeResponse
112   , config :: SessionConfig
113   , sessionCapabilities :: ClientCapabilities
114   }
115
116 class Monad m => HasReader r m where
117   ask :: m r
118   asks :: (r -> b) -> m b
119   asks f = f <$> ask
120
121 instance Monad m => HasReader r (ParserStateReader a s r m) where
122   ask = lift $ lift Reader.ask
123
124 instance Monad m => HasReader SessionContext (ConduitM a b (StateT s (ReaderT SessionContext m))) where
125   ask = lift $ lift Reader.ask
126
127 data SessionState = SessionState
128   {
129     curReqId :: LspId
130   , vfs :: VFS
131   , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
132   , curTimeoutId :: Int
133   , overridingTimeout :: Bool
134   -- ^ The last received message from the server.
135   -- Used for providing exception information
136   , lastReceivedMessage :: Maybe FromServerMessage
137   }
138
139 class Monad m => HasState s m where
140   get :: m s
141
142   put :: s -> m ()
143
144   modify :: (s -> s) -> m ()
145   modify f = get >>= put . f
146
147   modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
148   modifyM f = get >>= f >>= put
149
150 instance Monad m => HasState s (ParserStateReader a s r m) where
151   get = lift State.get
152   put = lift . State.put
153
154 instance Monad m => HasState SessionState (ConduitM a b (StateT SessionState m))
155  where
156   get = lift State.get
157   put = lift . State.put
158
159 type ParserStateReader a s r m = ConduitParser a (StateT s (ReaderT r m))
160
161 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
162 runSession context state session = runReaderT (runStateT conduit state) context
163   where
164     conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
165
166     handler (Unexpected "ConduitParser.empty") = do
167       lastMsg <- fromJust . lastReceivedMessage <$> get
168       name <- getParserName
169       liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
170
171     handler e = throw e
172
173     chanSource = do
174       msg <- liftIO $ readChan (messageChan context)
175       yield msg
176       chanSource
177
178     watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
179     watchdog = Conduit.awaitForever $ \msg -> do
180       curId <- curTimeoutId <$> get
181       case msg of
182         ServerMessage sMsg -> yield sMsg
183         TimeoutMessage tId -> when (curId == tId) $ throw Timeout
184
185 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
186 -- It also does not automatically send initialize and exit messages.
187 runSessionWithHandles :: Handle -- ^ Server in
188                       -> Handle -- ^ Server out
189                       -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
190                       -> SessionConfig
191                       -> ClientCapabilities
192                       -> FilePath -- ^ Root directory
193                       -> Session () -- ^ To exit Server
194                       -> Session a
195                       -> IO a
196 runSessionWithHandles serverIn serverOut serverHandler config caps rootDir exitServer session = do
197   
198   absRootDir <- canonicalizePath rootDir
199
200   hSetBuffering serverIn  NoBuffering
201   hSetBuffering serverOut NoBuffering
202   -- This is required to make sure that we don’t get any
203   -- newline conversion or weird encoding issues.
204   hSetBinaryMode serverIn True
205   hSetBinaryMode serverOut True
206
207   reqMap <- newMVar newRequestMap
208   messageChan <- newChan
209   initRsp <- newEmptyMVar
210
211   mainThreadId <- myThreadId
212
213   let context = SessionContext serverIn absRootDir messageChan reqMap initRsp config caps
214       initState = SessionState (IdInt 0) mempty mempty 0 False Nothing
215       runSession' = runSession context initState
216       
217       errorHandler = throwTo mainThreadId :: SessionException -> IO()
218       serverLauncher = forkIO $ catch (serverHandler serverOut context) errorHandler
219       serverFinalizer tid = runSession' exitServer >> killThread tid
220       
221   (result, _) <- bracket serverLauncher serverFinalizer (const $ runSession' session)
222   return result
223
224 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
225 updateStateC = awaitForever $ \msg -> do
226   updateState msg
227   yield msg
228
229 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m) => FromServerMessage -> m ()
230 updateState (NotPublishDiagnostics n) = do
231   let List diags = n ^. params . diagnostics
232       doc = n ^. params . uri
233   modify (\s ->
234     let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s)
235       in s { curDiagnostics = newDiags })
236
237 updateState (ReqApplyWorkspaceEdit r) = do
238
239   allChangeParams <- case r ^. params . edit . documentChanges of
240     Just (List cs) -> do
241       mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
242       return $ map getParams cs
243     Nothing -> case r ^. params . edit . changes of
244       Just cs -> do
245         mapM_ checkIfNeedsOpened (HashMap.keys cs)
246         return $ concatMap (uncurry getChangeParams) (HashMap.toList cs)
247       Nothing -> error "No changes!"
248
249   modifyM $ \s -> do
250     newVFS <- liftIO $ changeFromServerVFS (vfs s) r
251     return $ s { vfs = newVFS }
252
253   let groupedParams = groupBy (\a b -> a ^. textDocument == b ^. textDocument) allChangeParams
254       mergedParams = map mergeParams groupedParams
255
256   -- TODO: Don't do this when replaying a session
257   forM_ mergedParams (sendMessage . NotificationMessage "2.0" TextDocumentDidChange)
258
259   -- Update VFS to new document versions
260   let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
261       latestVersions = map ((^. textDocument) . last) sortedVersions
262       bumpedVersions = map (version . _Just +~ 1) latestVersions
263
264   forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
265     modify $ \s ->
266       let oldVFS = vfs s
267           update (VirtualFile oldV t mf) = VirtualFile (fromMaybe oldV v) t mf
268           newVFS = Map.adjust update (toNormalizedUri uri) oldVFS
269       in s { vfs = newVFS }
270
271   where checkIfNeedsOpened uri = do
272           oldVFS <- vfs <$> get
273           ctx <- ask
274
275           -- if its not open, open it
276           unless (toNormalizedUri uri `Map.member` oldVFS) $ do
277             let fp = fromJust $ uriToFilePath uri
278             contents <- liftIO $ T.readFile fp
279             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
280                 msg = NotificationMessage "2.0" TextDocumentDidOpen (DidOpenTextDocumentParams item)
281             liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
282
283             modifyM $ \s -> do
284               newVFS <- liftIO $ openVFS (vfs s) msg
285               return $ s { vfs = newVFS }
286
287         getParams (TextDocumentEdit docId (List edits)) =
288           let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
289             in DidChangeTextDocumentParams docId (List changeEvents)
290
291         textDocumentVersions uri = map (VersionedTextDocumentIdentifier uri . Just) [0..]
292
293         textDocumentEdits uri edits = map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip (textDocumentVersions uri) edits
294
295         getChangeParams uri (List edits) = map getParams (textDocumentEdits uri (reverse edits))
296
297         mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
298         mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
299                               in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
300 updateState _ = return ()
301
302 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
303 sendMessage msg = do
304   h <- serverIn <$> ask
305   logMsg LogClient msg
306   liftIO $ B.hPut h (addHeader $ encode msg)
307
308 -- | Execute a block f that will throw a 'Timeout' exception
309 -- after duration seconds. This will override the global timeout
310 -- for waiting for messages to arrive defined in 'SessionConfig'.
311 withTimeout :: Int -> Session a -> Session a
312 withTimeout duration f = do
313   chan <- asks messageChan
314   timeoutId <- curTimeoutId <$> get
315   modify $ \s -> s { overridingTimeout = True }
316   liftIO $ forkIO $ do
317     threadDelay (duration * 1000000)
318     writeChan chan (TimeoutMessage timeoutId)
319   res <- f
320   modify $ \s -> s { curTimeoutId = timeoutId + 1,
321                      overridingTimeout = False
322                    }
323   return res
324
325 data LogMsgType = LogServer | LogClient
326   deriving Eq
327
328 -- | Logs the message if the config specified it
329 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
330        => LogMsgType -> a -> m ()
331 logMsg t msg = do
332   shouldLog <- asks $ logMessages . config
333   shouldColor <- asks $ logColor . config
334   liftIO $ when shouldLog $ do
335     when shouldColor $ setSGR [SetColor Foreground Dull color]
336     putStrLn $ arrow ++ showPretty msg
337     when shouldColor $ setSGR [Reset]
338
339   where arrow
340           | t == LogServer  = "<-- "
341           | otherwise       = "--> "
342         color
343           | t == LogServer  = Magenta
344           | otherwise       = Cyan
345
346         showPretty = B.unpack . encodePretty
347