8e1afa8c62e515b661576fcf43190cced7e47a15
[lsp-test.git] / src / Language / Haskell / LSP / Test / Session.hs
1 {-# LANGUAGE CPP               #-}
2 {-# LANGUAGE OverloadedStrings #-}
3 {-# LANGUAGE FlexibleInstances #-}
4 {-# LANGUAGE MultiParamTypeClasses #-}
5 {-# LANGUAGE FlexibleContexts #-}
6 {-# LANGUAGE RankNTypes #-}
7
8 module Language.Haskell.LSP.Test.Session
9   ( Session
10   , SessionConfig(..)
11   , defaultConfig
12   , SessionMessage(..)
13   , SessionContext(..)
14   , SessionState(..)
15   , runSessionWithHandles
16   , get
17   , put
18   , modify
19   , modifyM
20   , ask
21   , asks
22   , sendMessage
23   , updateState
24   , withTimeout
25   , logMsg
26   , LogMsgType(..)
27   )
28
29 where
30
31 import Control.Concurrent hiding (yield)
32 import Control.Exception
33 import Control.Lens hiding (List)
34 import Control.Monad
35 import Control.Monad.IO.Class
36 import Control.Monad.Except
37 #if __GLASGOW_HASKELL__ >= 806
38 import Control.Monad.Fail
39 #endif
40 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
41 import qualified Control.Monad.Trans.Reader as Reader (ask)
42 import Control.Monad.Trans.State (StateT, runStateT)
43 import qualified Control.Monad.Trans.State as State (get, put)
44 import qualified Data.ByteString.Lazy.Char8 as B
45 import Data.Aeson
46 import Data.Aeson.Encode.Pretty
47 import Data.Conduit as Conduit
48 import Data.Conduit.Parser as Parser
49 import Data.Default
50 import Data.Foldable
51 import Data.List
52 import qualified Data.Map as Map
53 import qualified Data.Text as T
54 import qualified Data.Text.IO as T
55 import qualified Data.HashMap.Strict as HashMap
56 import Data.Maybe
57 import Data.Function
58 import Language.Haskell.LSP.Messages
59 import Language.Haskell.LSP.Types.Capabilities
60 import Language.Haskell.LSP.Types
61 import Language.Haskell.LSP.Types.Lens hiding (error)
62 import Language.Haskell.LSP.VFS
63 import Language.Haskell.LSP.Test.Compat
64 import Language.Haskell.LSP.Test.Decoding
65 import Language.Haskell.LSP.Test.Exceptions
66 import System.Console.ANSI
67 import System.Directory
68 import System.IO
69 import System.Process (ProcessHandle())
70 import System.Timeout
71
72 -- | A session representing one instance of launching and connecting to a server.
73 --
74 -- You can send and receive messages to the server within 'Session' via
75 -- 'Language.Haskell.LSP.Test.message',
76 -- 'Language.Haskell.LSP.Test.sendRequest' and
77 -- 'Language.Haskell.LSP.Test.sendNotification'.
78
79 type Session = ParserStateReader FromServerMessage SessionState SessionContext IO
80
81 #if __GLASGOW_HASKELL__ >= 806
82 instance MonadFail Session where
83   fail s = do
84     lastMsg <- fromJust . lastReceivedMessage <$> get
85     liftIO $ throw (UnexpectedMessage s lastMsg)
86 #endif
87
88 -- | Stuff you can configure for a 'Session'.
89 data SessionConfig = SessionConfig
90   { messageTimeout :: Int  -- ^ Maximum time to wait for a message in seconds, defaults to 60.
91   , logStdErr      :: Bool -- ^ Redirect the server's stderr to this stdout, defaults to False.
92   , logMessages    :: Bool -- ^ Trace the messages sent and received to stdout, defaults to False.
93   , logColor       :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
94   , lspConfig      :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
95   }
96
97 -- | The configuration used in 'Language.Haskell.LSP.Test.runSession'.
98 defaultConfig :: SessionConfig
99 defaultConfig = SessionConfig 60 False False True Nothing
100
101 instance Default SessionConfig where
102   def = defaultConfig
103
104 data SessionMessage = ServerMessage FromServerMessage
105                     | TimeoutMessage Int
106   deriving Show
107
108 data SessionContext = SessionContext
109   {
110     serverIn :: Handle
111   , rootDir :: FilePath
112   , messageChan :: Chan SessionMessage
113   , requestMap :: MVar RequestMap
114   , initRsp :: MVar InitializeResponse
115   , config :: SessionConfig
116   , sessionCapabilities :: ClientCapabilities
117   }
118
119 class Monad m => HasReader r m where
120   ask :: m r
121   asks :: (r -> b) -> m b
122   asks f = f <$> ask
123
124 instance Monad m => HasReader r (ParserStateReader a s r m) where
125   ask = lift $ lift Reader.ask
126
127 instance Monad m => HasReader SessionContext (ConduitM a b (StateT s (ReaderT SessionContext m))) where
128   ask = lift $ lift Reader.ask
129
130 data SessionState = SessionState
131   {
132     curReqId :: LspId
133   , vfs :: VFS
134   , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
135   , curTimeoutId :: Int
136   , overridingTimeout :: Bool
137   -- ^ The last received message from the server.
138   -- Used for providing exception information
139   , lastReceivedMessage :: Maybe FromServerMessage
140   }
141
142 class Monad m => HasState s m where
143   get :: m s
144
145   put :: s -> m ()
146
147   modify :: (s -> s) -> m ()
148   modify f = get >>= put . f
149
150   modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
151   modifyM f = get >>= f >>= put
152
153 instance Monad m => HasState s (ParserStateReader a s r m) where
154   get = lift State.get
155   put = lift . State.put
156
157 instance Monad m => HasState SessionState (ConduitM a b (StateT SessionState m))
158  where
159   get = lift State.get
160   put = lift . State.put
161
162 type ParserStateReader a s r m = ConduitParser a (StateT s (ReaderT r m))
163
164 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
165 runSession context state session = runReaderT (runStateT conduit state) context
166   where
167     conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
168
169     handler (Unexpected "ConduitParser.empty") = do
170       lastMsg <- fromJust . lastReceivedMessage <$> get
171       name <- getParserName
172       liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
173
174     handler e = throw e
175
176     chanSource = do
177       msg <- liftIO $ readChan (messageChan context)
178       yield msg
179       chanSource
180
181     watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
182     watchdog = Conduit.awaitForever $ \msg -> do
183       curId <- curTimeoutId <$> get
184       case msg of
185         ServerMessage sMsg -> yield sMsg
186         TimeoutMessage tId -> when (curId == tId) $ throw Timeout
187
188 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
189 -- It also does not automatically send initialize and exit messages.
190 runSessionWithHandles :: Handle -- ^ Server in
191                       -> Handle -- ^ Server out
192                       -> ProcessHandle -- ^ Server process
193                       -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
194                       -> SessionConfig
195                       -> ClientCapabilities
196                       -> FilePath -- ^ Root directory
197                       -> Session () -- ^ To exit the Server properly
198                       -> Session a
199                       -> IO a
200 runSessionWithHandles serverIn serverOut serverProc serverHandler config caps rootDir exitServer session = do
201   absRootDir <- canonicalizePath rootDir
202
203   hSetBuffering serverIn  NoBuffering
204   hSetBuffering serverOut NoBuffering
205   -- This is required to make sure that we don’t get any
206   -- newline conversion or weird encoding issues.
207   hSetBinaryMode serverIn True
208   hSetBinaryMode serverOut True
209
210   reqMap <- newMVar newRequestMap
211   messageChan <- newChan
212   initRsp <- newEmptyMVar
213
214   mainThreadId <- myThreadId
215
216   let context = SessionContext serverIn absRootDir messageChan reqMap initRsp config caps
217       initState = SessionState (IdInt 0) mempty mempty 0 False Nothing
218       runSession' = runSession context initState
219
220       errorHandler = throwTo mainThreadId :: SessionException -> IO()
221       serverListenerLauncher =
222         forkIO $ catch (serverHandler serverOut context) errorHandler
223       server = (Just serverIn, Just serverOut, Nothing, serverProc)
224       serverAndListenerFinalizer tid =
225         finally (timeout (messageTimeout config * 1000000)
226                          (runSession' exitServer))
227                 (cleanupProcess server >> killThread tid)
228
229   (result, _) <- bracket serverListenerLauncher serverAndListenerFinalizer
230                          (const $ runSession' session)
231   return result
232
233 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
234 updateStateC = awaitForever $ \msg -> do
235   updateState msg
236   yield msg
237
238 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m) => FromServerMessage -> m ()
239 updateState (NotPublishDiagnostics n) = do
240   let List diags = n ^. params . diagnostics
241       doc = n ^. params . uri
242   modify (\s ->
243     let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s)
244       in s { curDiagnostics = newDiags })
245
246 updateState (ReqApplyWorkspaceEdit r) = do
247
248   allChangeParams <- case r ^. params . edit . documentChanges of
249     Just (List cs) -> do
250       mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
251       return $ map getParams cs
252     Nothing -> case r ^. params . edit . changes of
253       Just cs -> do
254         mapM_ checkIfNeedsOpened (HashMap.keys cs)
255         return $ concatMap (uncurry getChangeParams) (HashMap.toList cs)
256       Nothing -> error "No changes!"
257
258   modifyM $ \s -> do
259     newVFS <- liftIO $ changeFromServerVFS (vfs s) r
260     return $ s { vfs = newVFS }
261
262   let groupedParams = groupBy (\a b -> a ^. textDocument == b ^. textDocument) allChangeParams
263       mergedParams = map mergeParams groupedParams
264
265   -- TODO: Don't do this when replaying a session
266   forM_ mergedParams (sendMessage . NotificationMessage "2.0" TextDocumentDidChange)
267
268   -- Update VFS to new document versions
269   let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
270       latestVersions = map ((^. textDocument) . last) sortedVersions
271       bumpedVersions = map (version . _Just +~ 1) latestVersions
272
273   forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
274     modify $ \s ->
275       let oldVFS = vfs s
276           update (VirtualFile oldV t mf) = VirtualFile (fromMaybe oldV v) t mf
277           newVFS = Map.adjust update (toNormalizedUri uri) oldVFS
278       in s { vfs = newVFS }
279
280   where checkIfNeedsOpened uri = do
281           oldVFS <- vfs <$> get
282           ctx <- ask
283
284           -- if its not open, open it
285           unless (toNormalizedUri uri `Map.member` oldVFS) $ do
286             let fp = fromJust $ uriToFilePath uri
287             contents <- liftIO $ T.readFile fp
288             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
289                 msg = NotificationMessage "2.0" TextDocumentDidOpen (DidOpenTextDocumentParams item)
290             liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
291
292             modifyM $ \s -> do
293               newVFS <- liftIO $ openVFS (vfs s) msg
294               return $ s { vfs = newVFS }
295
296         getParams (TextDocumentEdit docId (List edits)) =
297           let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
298             in DidChangeTextDocumentParams docId (List changeEvents)
299
300         textDocumentVersions uri = map (VersionedTextDocumentIdentifier uri . Just) [0..]
301
302         textDocumentEdits uri edits = map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip (textDocumentVersions uri) edits
303
304         getChangeParams uri (List edits) = map getParams (textDocumentEdits uri (reverse edits))
305
306         mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
307         mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
308                               in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
309 updateState _ = return ()
310
311 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
312 sendMessage msg = do
313   h <- serverIn <$> ask
314   logMsg LogClient msg
315   liftIO $ B.hPut h (addHeader $ encode msg)
316
317 -- | Execute a block f that will throw a 'Timeout' exception
318 -- after duration seconds. This will override the global timeout
319 -- for waiting for messages to arrive defined in 'SessionConfig'.
320 withTimeout :: Int -> Session a -> Session a
321 withTimeout duration f = do
322   chan <- asks messageChan
323   timeoutId <- curTimeoutId <$> get
324   modify $ \s -> s { overridingTimeout = True }
325   liftIO $ forkIO $ do
326     threadDelay (duration * 1000000)
327     writeChan chan (TimeoutMessage timeoutId)
328   res <- f
329   modify $ \s -> s { curTimeoutId = timeoutId + 1,
330                      overridingTimeout = False
331                    }
332   return res
333
334 data LogMsgType = LogServer | LogClient
335   deriving Eq
336
337 -- | Logs the message if the config specified it
338 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
339        => LogMsgType -> a -> m ()
340 logMsg t msg = do
341   shouldLog <- asks $ logMessages . config
342   shouldColor <- asks $ logColor . config
343   liftIO $ when shouldLog $ do
344     when shouldColor $ setSGR [SetColor Foreground Dull color]
345     putStrLn $ arrow ++ showPretty msg
346     when shouldColor $ setSGR [Reset]
347
348   where arrow
349           | t == LogServer  = "<-- "
350           | otherwise       = "--> "
351         color
352           | t == LogServer  = Magenta
353           | otherwise       = Cyan
354
355         showPretty = B.unpack . encodePretty
356