Don't use exitServer in Replay
[lsp-test.git] / src / Language / Haskell / LSP / Test / Session.hs
1 {-# LANGUAGE CPP               #-}
2 {-# LANGUAGE OverloadedStrings #-}
3 {-# LANGUAGE FlexibleInstances #-}
4 {-# LANGUAGE MultiParamTypeClasses #-}
5 {-# LANGUAGE FlexibleContexts #-}
6 {-# LANGUAGE RankNTypes #-}
7
8 module Language.Haskell.LSP.Test.Session
9   ( Session
10   , SessionConfig(..)
11   , defaultConfig
12   , SessionMessage(..)
13   , SessionContext(..)
14   , SessionState(..)
15   , runSessionWithHandles
16   , get
17   , put
18   , modify
19   , modifyM
20   , ask
21   , asks
22   , sendMessage
23   , updateState
24   , withTimeout
25   , logMsg
26   , LogMsgType(..)
27   )
28
29 where
30
31 import Control.Concurrent hiding (yield)
32 import Control.Exception
33 import Control.Lens hiding (List)
34 import Control.Monad
35 import Control.Monad.IO.Class
36 import Control.Monad.Except
37 #if __GLASGOW_HASKELL__ >= 806
38 import Control.Monad.Fail
39 #endif
40 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
41 import qualified Control.Monad.Trans.Reader as Reader (ask)
42 import Control.Monad.Trans.State (StateT, runStateT)
43 import qualified Control.Monad.Trans.State as State (get, put)
44 import qualified Data.ByteString.Lazy.Char8 as B
45 import Data.Aeson
46 import Data.Aeson.Encode.Pretty
47 import Data.Conduit as Conduit
48 import Data.Conduit.Parser as Parser
49 import Data.Default
50 import Data.Foldable
51 import Data.List
52 import qualified Data.Map as Map
53 import qualified Data.Text as T
54 import qualified Data.Text.IO as T
55 import qualified Data.HashMap.Strict as HashMap
56 import Data.Maybe
57 import Data.Function
58 import Language.Haskell.LSP.Messages
59 import Language.Haskell.LSP.Types.Capabilities
60 import Language.Haskell.LSP.Types
61 import Language.Haskell.LSP.Types.Lens hiding (error)
62 import Language.Haskell.LSP.VFS
63 import Language.Haskell.LSP.Test.Compat
64 import Language.Haskell.LSP.Test.Decoding
65 import Language.Haskell.LSP.Test.Exceptions
66 import System.Console.ANSI
67 import System.Directory
68 import System.IO
69 import System.Process (ProcessHandle())
70 import System.Timeout
71
72 -- | A session representing one instance of launching and connecting to a server.
73 --
74 -- You can send and receive messages to the server within 'Session' via
75 -- 'Language.Haskell.LSP.Test.message',
76 -- 'Language.Haskell.LSP.Test.sendRequest' and
77 -- 'Language.Haskell.LSP.Test.sendNotification'.
78
79 type Session = ParserStateReader FromServerMessage SessionState SessionContext IO
80
81 #if __GLASGOW_HASKELL__ >= 806
82 instance MonadFail Session where
83   fail s = do
84     lastMsg <- fromJust . lastReceivedMessage <$> get
85     liftIO $ throw (UnexpectedMessage s lastMsg)
86 #endif
87
88 -- | Stuff you can configure for a 'Session'.
89 data SessionConfig = SessionConfig
90   { messageTimeout :: Int  -- ^ Maximum time to wait for a message in seconds, defaults to 60.
91   , logStdErr      :: Bool -- ^ Redirect the server's stderr to this stdout, defaults to False.
92   , logMessages    :: Bool -- ^ Trace the messages sent and received to stdout, defaults to False.
93   , logColor       :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
94   , lspConfig      :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
95   }
96
97 -- | The configuration used in 'Language.Haskell.LSP.Test.runSession'.
98 defaultConfig :: SessionConfig
99 defaultConfig = SessionConfig 60 False False True Nothing
100
101 instance Default SessionConfig where
102   def = defaultConfig
103
104 data SessionMessage = ServerMessage FromServerMessage
105                     | TimeoutMessage Int
106   deriving Show
107
108 data SessionContext = SessionContext
109   {
110     serverIn :: Handle
111   , rootDir :: FilePath
112   , messageChan :: Chan SessionMessage
113   , requestMap :: MVar RequestMap
114   , initRsp :: MVar InitializeResponse
115   , config :: SessionConfig
116   , sessionCapabilities :: ClientCapabilities
117   }
118
119 class Monad m => HasReader r m where
120   ask :: m r
121   asks :: (r -> b) -> m b
122   asks f = f <$> ask
123
124 instance Monad m => HasReader r (ParserStateReader a s r m) where
125   ask = lift $ lift Reader.ask
126
127 instance Monad m => HasReader SessionContext (ConduitM a b (StateT s (ReaderT SessionContext m))) where
128   ask = lift $ lift Reader.ask
129
130 data SessionState = SessionState
131   {
132     curReqId :: LspId
133   , vfs :: VFS
134   , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
135   , curTimeoutId :: Int
136   , overridingTimeout :: Bool
137   -- ^ The last received message from the server.
138   -- Used for providing exception information
139   , lastReceivedMessage :: Maybe FromServerMessage
140   }
141
142 class Monad m => HasState s m where
143   get :: m s
144
145   put :: s -> m ()
146
147   modify :: (s -> s) -> m ()
148   modify f = get >>= put . f
149
150   modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
151   modifyM f = get >>= f >>= put
152
153 instance Monad m => HasState s (ParserStateReader a s r m) where
154   get = lift State.get
155   put = lift . State.put
156
157 instance Monad m => HasState SessionState (ConduitM a b (StateT SessionState m))
158  where
159   get = lift State.get
160   put = lift . State.put
161
162 type ParserStateReader a s r m = ConduitParser a (StateT s (ReaderT r m))
163
164 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
165 runSession context state session = runReaderT (runStateT conduit state) context
166   where
167     conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
168
169     handler (Unexpected "ConduitParser.empty") = do
170       lastMsg <- fromJust . lastReceivedMessage <$> get
171       name <- getParserName
172       liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
173
174     handler e = throw e
175
176     chanSource = do
177       msg <- liftIO $ readChan (messageChan context)
178       yield msg
179       chanSource
180
181     watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
182     watchdog = Conduit.awaitForever $ \msg -> do
183       curId <- curTimeoutId <$> get
184       case msg of
185         ServerMessage sMsg -> yield sMsg
186         TimeoutMessage tId -> when (curId == tId) $ throw Timeout
187
188 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
189 -- It also does not automatically send initialize and exit messages.
190 runSessionWithHandles :: Handle -- ^ Server in
191                       -> Handle -- ^ Server out
192                       -> ProcessHandle -- ^ Server process
193                       -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
194                       -> SessionConfig
195                       -> ClientCapabilities
196                       -> FilePath -- ^ Root directory
197                       -> Session () -- ^ To exit the Server properly
198                       -> Session a
199                       -> IO a
200 runSessionWithHandles serverIn serverOut serverProc serverHandler config caps rootDir exitServer session = do
201   absRootDir <- canonicalizePath rootDir
202
203   hSetBuffering serverIn  NoBuffering
204   hSetBuffering serverOut NoBuffering
205   -- This is required to make sure that we don’t get any
206   -- newline conversion or weird encoding issues.
207   hSetBinaryMode serverIn True
208   hSetBinaryMode serverOut True
209
210   reqMap <- newMVar newRequestMap
211   messageChan <- newChan
212   initRsp <- newEmptyMVar
213
214   mainThreadId <- myThreadId
215
216   let context = SessionContext serverIn absRootDir messageChan reqMap initRsp config caps
217       initState = SessionState (IdInt 0) mempty mempty 0 False Nothing
218       runSession' = runSession context initState
219
220       errorHandler = throwTo mainThreadId :: SessionException -> IO()
221       serverLauncher = forkIO $ catch (serverHandler serverOut context) errorHandler
222       server = (Just serverIn, Just serverOut, Nothing, serverProc)
223       serverFinalizer tid = finally (timeout (messageTimeout config * 1000000)
224                                              (runSession' exitServer))
225                                     (cleanupRunningProcess server >> killThread tid)
226
227   (result, _) <- bracket serverLauncher serverFinalizer (const $ runSession' session)
228   return result
229
230 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
231 updateStateC = awaitForever $ \msg -> do
232   updateState msg
233   yield msg
234
235 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m) => FromServerMessage -> m ()
236 updateState (NotPublishDiagnostics n) = do
237   let List diags = n ^. params . diagnostics
238       doc = n ^. params . uri
239   modify (\s ->
240     let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s)
241       in s { curDiagnostics = newDiags })
242
243 updateState (ReqApplyWorkspaceEdit r) = do
244
245   allChangeParams <- case r ^. params . edit . documentChanges of
246     Just (List cs) -> do
247       mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
248       return $ map getParams cs
249     Nothing -> case r ^. params . edit . changes of
250       Just cs -> do
251         mapM_ checkIfNeedsOpened (HashMap.keys cs)
252         return $ concatMap (uncurry getChangeParams) (HashMap.toList cs)
253       Nothing -> error "No changes!"
254
255   modifyM $ \s -> do
256     newVFS <- liftIO $ changeFromServerVFS (vfs s) r
257     return $ s { vfs = newVFS }
258
259   let groupedParams = groupBy (\a b -> a ^. textDocument == b ^. textDocument) allChangeParams
260       mergedParams = map mergeParams groupedParams
261
262   -- TODO: Don't do this when replaying a session
263   forM_ mergedParams (sendMessage . NotificationMessage "2.0" TextDocumentDidChange)
264
265   -- Update VFS to new document versions
266   let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
267       latestVersions = map ((^. textDocument) . last) sortedVersions
268       bumpedVersions = map (version . _Just +~ 1) latestVersions
269
270   forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
271     modify $ \s ->
272       let oldVFS = vfs s
273           update (VirtualFile oldV t mf) = VirtualFile (fromMaybe oldV v) t mf
274           newVFS = Map.adjust update (toNormalizedUri uri) oldVFS
275       in s { vfs = newVFS }
276
277   where checkIfNeedsOpened uri = do
278           oldVFS <- vfs <$> get
279           ctx <- ask
280
281           -- if its not open, open it
282           unless (toNormalizedUri uri `Map.member` oldVFS) $ do
283             let fp = fromJust $ uriToFilePath uri
284             contents <- liftIO $ T.readFile fp
285             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
286                 msg = NotificationMessage "2.0" TextDocumentDidOpen (DidOpenTextDocumentParams item)
287             liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
288
289             modifyM $ \s -> do
290               newVFS <- liftIO $ openVFS (vfs s) msg
291               return $ s { vfs = newVFS }
292
293         getParams (TextDocumentEdit docId (List edits)) =
294           let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
295             in DidChangeTextDocumentParams docId (List changeEvents)
296
297         textDocumentVersions uri = map (VersionedTextDocumentIdentifier uri . Just) [0..]
298
299         textDocumentEdits uri edits = map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip (textDocumentVersions uri) edits
300
301         getChangeParams uri (List edits) = map getParams (textDocumentEdits uri (reverse edits))
302
303         mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
304         mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
305                               in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
306 updateState _ = return ()
307
308 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
309 sendMessage msg = do
310   h <- serverIn <$> ask
311   logMsg LogClient msg
312   liftIO $ B.hPut h (addHeader $ encode msg)
313
314 -- | Execute a block f that will throw a 'Timeout' exception
315 -- after duration seconds. This will override the global timeout
316 -- for waiting for messages to arrive defined in 'SessionConfig'.
317 withTimeout :: Int -> Session a -> Session a
318 withTimeout duration f = do
319   chan <- asks messageChan
320   timeoutId <- curTimeoutId <$> get
321   modify $ \s -> s { overridingTimeout = True }
322   liftIO $ forkIO $ do
323     threadDelay (duration * 1000000)
324     writeChan chan (TimeoutMessage timeoutId)
325   res <- f
326   modify $ \s -> s { curTimeoutId = timeoutId + 1,
327                      overridingTimeout = False
328                    }
329   return res
330
331 data LogMsgType = LogServer | LogClient
332   deriving Eq
333
334 -- | Logs the message if the config specified it
335 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
336        => LogMsgType -> a -> m ()
337 logMsg t msg = do
338   shouldLog <- asks $ logMessages . config
339   shouldColor <- asks $ logColor . config
340   liftIO $ when shouldLog $ do
341     when shouldColor $ setSGR [SetColor Foreground Dull color]
342     putStrLn $ arrow ++ showPretty msg
343     when shouldColor $ setSGR [Reset]
344
345   where arrow
346           | t == LogServer  = "<-- "
347           | otherwise       = "--> "
348         color
349           | t == LogServer  = Magenta
350           | otherwise       = Cyan
351
352         showPretty = B.unpack . encodePretty
353