Track upstream
[lsp-test.git] / src / Language / Haskell / LSP / Test / Session.hs
1 {-# LANGUAGE CPP               #-}
2 {-# LANGUAGE OverloadedStrings #-}
3 {-# LANGUAGE GeneralizedNewtypeDeriving #-}
4 {-# LANGUAGE FlexibleInstances #-}
5 {-# LANGUAGE MultiParamTypeClasses #-}
6 {-# LANGUAGE FlexibleContexts #-}
7 {-# LANGUAGE RankNTypes #-}
8
9 module Language.Haskell.LSP.Test.Session
10   ( Session(..)
11   , SessionConfig(..)
12   , defaultConfig
13   , SessionMessage(..)
14   , SessionContext(..)
15   , SessionState(..)
16   , runSessionWithHandles
17   , get
18   , put
19   , modify
20   , modifyM
21   , ask
22   , asks
23   , sendMessage
24   , updateState
25   , withTimeout
26   , logMsg
27   , LogMsgType(..)
28   )
29
30 where
31
32 import Control.Applicative
33 import Control.Concurrent hiding (yield)
34 import Control.Exception
35 import Control.Lens hiding (List)
36 import Control.Monad
37 import Control.Monad.IO.Class
38 import Control.Monad.Except
39 #if __GLASGOW_HASKELL__ == 806
40 import Control.Monad.Fail
41 #endif
42 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
43 import qualified Control.Monad.Trans.Reader as Reader (ask)
44 import Control.Monad.Trans.State (StateT, runStateT)
45 import qualified Control.Monad.Trans.State as State
46 import qualified Data.ByteString.Lazy.Char8 as B
47 import Data.Aeson
48 import Data.Aeson.Encode.Pretty
49 import Data.Conduit as Conduit
50 import Data.Conduit.Parser as Parser
51 import Data.Default
52 import Data.Foldable
53 import Data.List
54 import qualified Data.Map as Map
55 import qualified Data.Text as T
56 import qualified Data.Text.IO as T
57 import qualified Data.HashMap.Strict as HashMap
58 import Data.Maybe
59 import Data.Function
60 import Language.Haskell.LSP.Messages
61 import Language.Haskell.LSP.Types.Capabilities
62 import Language.Haskell.LSP.Types
63 import Language.Haskell.LSP.Types.Lens hiding (error)
64 import Language.Haskell.LSP.VFS
65 import Language.Haskell.LSP.Test.Compat
66 import Language.Haskell.LSP.Test.Decoding
67 import Language.Haskell.LSP.Test.Exceptions
68 import System.Console.ANSI
69 import System.Directory
70 import System.IO
71 import System.Process (ProcessHandle())
72 import System.Timeout
73 import System.IO.Temp
74
75 -- | A session representing one instance of launching and connecting to a server.
76 --
77 -- You can send and receive messages to the server within 'Session' via
78 -- 'Language.Haskell.LSP.Test.message',
79 -- 'Language.Haskell.LSP.Test.sendRequest' and
80 -- 'Language.Haskell.LSP.Test.sendNotification'.
81
82 newtype Session a = Session (ConduitParser FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) a)
83   deriving (Functor, Applicative, Monad, MonadIO, Alternative)
84
85 #if __GLASGOW_HASKELL__ >= 806
86 instance MonadFail Session where
87   fail s = do
88     lastMsg <- fromJust . lastReceivedMessage <$> get
89     liftIO $ throw (UnexpectedMessage s lastMsg)
90 #endif
91
92 -- | Stuff you can configure for a 'Session'.
93 data SessionConfig = SessionConfig
94   { messageTimeout :: Int  -- ^ Maximum time to wait for a message in seconds, defaults to 60.
95   , logStdErr      :: Bool -- ^ Redirect the server's stderr to this stdout, defaults to False.
96   , logMessages    :: Bool -- ^ Trace the messages sent and received to stdout, defaults to False.
97   , logColor       :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
98   , lspConfig      :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
99   }
100
101 -- | The configuration used in 'Language.Haskell.LSP.Test.runSession'.
102 defaultConfig :: SessionConfig
103 defaultConfig = SessionConfig 60 False False True Nothing
104
105 instance Default SessionConfig where
106   def = defaultConfig
107
108 data SessionMessage = ServerMessage FromServerMessage
109                     | TimeoutMessage Int
110   deriving Show
111
112 data SessionContext = SessionContext
113   {
114     serverIn :: Handle
115   , rootDir :: FilePath
116   , messageChan :: Chan SessionMessage
117   , requestMap :: MVar RequestMap
118   , initRsp :: MVar InitializeResponse
119   , config :: SessionConfig
120   , sessionCapabilities :: ClientCapabilities
121   }
122
123 class Monad m => HasReader r m where
124   ask :: m r
125   asks :: (r -> b) -> m b
126   asks f = f <$> ask
127
128 instance HasReader SessionContext Session where
129   ask  = Session (lift $ lift Reader.ask)
130
131 instance Monad m => HasReader r (ConduitM a b (StateT s (ReaderT r m))) where
132   ask = lift $ lift Reader.ask
133
134 data SessionState = SessionState
135   {
136     curReqId :: LspId
137   , vfs :: VFS
138   , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
139   , curTimeoutId :: Int
140   , overridingTimeout :: Bool
141   -- ^ The last received message from the server.
142   -- Used for providing exception information
143   , lastReceivedMessage :: Maybe FromServerMessage
144   }
145
146 class Monad m => HasState s m where
147   get :: m s
148
149   put :: s -> m ()
150
151   modify :: (s -> s) -> m ()
152   modify f = get >>= put . f
153
154   modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
155   modifyM f = get >>= f >>= put
156
157 instance HasState SessionState Session where
158   get = Session (lift State.get)
159   put = Session . lift . State.put
160
161 instance Monad m => HasState s (ConduitM a b (StateT s m))
162  where
163   get = lift State.get
164   put = lift . State.put
165
166 instance Monad m => HasState s (ConduitParser a (StateT s m))
167  where
168   get = lift State.get
169   put = lift . State.put
170
171 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
172 runSession context state (Session session) = runReaderT (runStateT conduit state) context
173   where
174     conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
175
176     handler (Unexpected "ConduitParser.empty") = do
177       lastMsg <- fromJust . lastReceivedMessage <$> get
178       name <- getParserName
179       liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
180
181     handler e = throw e
182
183     chanSource = do
184       msg <- liftIO $ readChan (messageChan context)
185       yield msg
186       chanSource
187
188     watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
189     watchdog = Conduit.awaitForever $ \msg -> do
190       curId <- curTimeoutId <$> get
191       case msg of
192         ServerMessage sMsg -> yield sMsg
193         TimeoutMessage tId -> when (curId == tId) $ throw Timeout
194
195 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
196 -- It also does not automatically send initialize and exit messages.
197 runSessionWithHandles :: Handle -- ^ Server in
198                       -> Handle -- ^ Server out
199                       -> ProcessHandle -- ^ Server process
200                       -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
201                       -> SessionConfig
202                       -> ClientCapabilities
203                       -> FilePath -- ^ Root directory
204                       -> Session () -- ^ To exit the Server properly
205                       -> Session a
206                       -> IO a
207 runSessionWithHandles serverIn serverOut serverProc serverHandler config caps rootDir exitServer session = do
208   absRootDir <- canonicalizePath rootDir
209
210   hSetBuffering serverIn  NoBuffering
211   hSetBuffering serverOut NoBuffering
212   -- This is required to make sure that we don’t get any
213   -- newline conversion or weird encoding issues.
214   hSetBinaryMode serverIn True
215   hSetBinaryMode serverOut True
216
217   reqMap <- newMVar newRequestMap
218   messageChan <- newChan
219   initRsp <- newEmptyMVar
220
221   mainThreadId <- myThreadId
222
223   let context = SessionContext serverIn absRootDir messageChan reqMap initRsp config caps
224       initState tmp_dir = SessionState (IdInt 0) (VFS mempty tmp_dir)
225                                        mempty 0 False Nothing
226       runSession' ses = withSystemTempDirectory "lsp-test" $ \tmp_dir ->
227                       runSession context (initState tmp_dir) ses
228
229       errorHandler = throwTo mainThreadId :: SessionException -> IO()
230       serverListenerLauncher =
231         forkIO $ catch (serverHandler serverOut context) errorHandler
232       server = (Just serverIn, Just serverOut, Nothing, serverProc)
233       serverAndListenerFinalizer tid =
234         finally (timeout (messageTimeout config * 1000000)
235                          (runSession' exitServer))
236                 (cleanupProcess server >> killThread tid)
237
238   (result, _) <- bracket serverListenerLauncher serverAndListenerFinalizer
239                          (const $ runSession' session)
240   return result
241
242 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
243 updateStateC = awaitForever $ \msg -> do
244   updateState msg
245   yield msg
246
247 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m)
248             => FromServerMessage -> m ()
249 updateState (NotPublishDiagnostics n) = do
250   let List diags = n ^. params . diagnostics
251       doc = n ^. params . uri
252   modify (\s ->
253     let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s)
254       in s { curDiagnostics = newDiags })
255
256 updateState (ReqApplyWorkspaceEdit r) = do
257
258   allChangeParams <- case r ^. params . edit . documentChanges of
259     Just (List cs) -> do
260       mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
261       return $ map getParams cs
262     Nothing -> case r ^. params . edit . changes of
263       Just cs -> do
264         mapM_ checkIfNeedsOpened (HashMap.keys cs)
265         return $ concatMap (uncurry getChangeParams) (HashMap.toList cs)
266       Nothing -> error "No changes!"
267
268   modifyM $ \s -> do
269     newVFS <- liftIO $ changeFromServerVFS (vfs s) r
270     return $ s { vfs = newVFS }
271
272   let groupedParams = groupBy (\a b -> a ^. textDocument == b ^. textDocument) allChangeParams
273       mergedParams = map mergeParams groupedParams
274
275   -- TODO: Don't do this when replaying a session
276   forM_ mergedParams (sendMessage . NotificationMessage "2.0" TextDocumentDidChange)
277
278   -- Update VFS to new document versions
279   let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
280       latestVersions = map ((^. textDocument) . last) sortedVersions
281       bumpedVersions = map (version . _Just +~ 1) latestVersions
282
283   forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
284     modify $ \s ->
285       let oldVFS = vfs s
286           update (VirtualFile oldV t) = VirtualFile (fromMaybe oldV v) t
287           newVFS = updateVFS (Map.adjust update (toNormalizedUri uri)) oldVFS
288       in s { vfs = newVFS }
289
290   where checkIfNeedsOpened uri = do
291           oldVFS <- vfs <$> get
292           ctx <- ask
293
294           -- if its not open, open it
295           unless (toNormalizedUri uri `Map.member` (vfsMap oldVFS)) $ do
296             let fp = fromJust $ uriToFilePath uri
297             contents <- liftIO $ T.readFile fp
298             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
299                 msg = NotificationMessage "2.0" TextDocumentDidOpen (DidOpenTextDocumentParams item)
300             liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
301
302             modifyM $ \s -> do
303               newVFS <- liftIO $ openVFS (vfs s) msg
304               return $ s { vfs = newVFS }
305
306         getParams (TextDocumentEdit docId (List edits)) =
307           let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
308             in DidChangeTextDocumentParams docId (List changeEvents)
309
310         textDocumentVersions uri = map (VersionedTextDocumentIdentifier uri . Just) [0..]
311
312         textDocumentEdits uri edits = map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip (textDocumentVersions uri) edits
313
314         getChangeParams uri (List edits) = map getParams (textDocumentEdits uri (reverse edits))
315
316         mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
317         mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
318                               in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
319 updateState _ = return ()
320
321 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
322 sendMessage msg = do
323   h <- serverIn <$> ask
324   logMsg LogClient msg
325   liftIO $ B.hPut h (addHeader $ encode msg)
326
327 -- | Execute a block f that will throw a 'Timeout' exception
328 -- after duration seconds. This will override the global timeout
329 -- for waiting for messages to arrive defined in 'SessionConfig'.
330 withTimeout :: Int -> Session a -> Session a
331 withTimeout duration f = do
332   chan <- asks messageChan
333   timeoutId <- curTimeoutId <$> get
334   modify $ \s -> s { overridingTimeout = True }
335   liftIO $ forkIO $ do
336     threadDelay (duration * 1000000)
337     writeChan chan (TimeoutMessage timeoutId)
338   res <- f
339   modify $ \s -> s { curTimeoutId = timeoutId + 1,
340                      overridingTimeout = False
341                    }
342   return res
343
344 data LogMsgType = LogServer | LogClient
345   deriving Eq
346
347 -- | Logs the message if the config specified it
348 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
349        => LogMsgType -> a -> m ()
350 logMsg t msg = do
351   shouldLog <- asks $ logMessages . config
352   shouldColor <- asks $ logColor . config
353   liftIO $ when shouldLog $ do
354     when shouldColor $ setSGR [SetColor Foreground Dull color]
355     putStrLn $ arrow ++ showPretty msg
356     when shouldColor $ setSGR [Reset]
357
358   where arrow
359           | t == LogServer  = "<-- "
360           | otherwise       = "--> "
361         color
362           | t == LogServer  = Magenta
363           | otherwise       = Cyan
364
365         showPretty = B.unpack . encodePretty
366