Start off on new version in didChanges from updateState
[lsp-test.git] / src / Language / Haskell / LSP / Test / Session.hs
1 {-# LANGUAGE CPP               #-}
2 {-# LANGUAGE OverloadedStrings #-}
3 {-# LANGUAGE GeneralizedNewtypeDeriving #-}
4 {-# LANGUAGE FlexibleInstances #-}
5 {-# LANGUAGE MultiParamTypeClasses #-}
6 {-# LANGUAGE FlexibleContexts #-}
7 {-# LANGUAGE RankNTypes #-}
8
9 module Language.Haskell.LSP.Test.Session
10   ( Session(..)
11   , SessionConfig(..)
12   , defaultConfig
13   , SessionMessage(..)
14   , SessionContext(..)
15   , SessionState(..)
16   , runSessionWithHandles
17   , get
18   , put
19   , modify
20   , modifyM
21   , ask
22   , asks
23   , sendMessage
24   , updateState
25   , withTimeout
26   , getCurTimeoutId
27   , bumpTimeoutId
28   , logMsg
29   , LogMsgType(..)
30   )
31
32 where
33
34 import Control.Applicative
35 import Control.Concurrent hiding (yield)
36 import Control.Exception
37 import Control.Lens hiding (List)
38 import Control.Monad
39 import Control.Monad.IO.Class
40 import Control.Monad.Except
41 #if __GLASGOW_HASKELL__ == 806
42 import Control.Monad.Fail
43 #endif
44 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
45 import qualified Control.Monad.Trans.Reader as Reader (ask)
46 import Control.Monad.Trans.State (StateT, runStateT)
47 import qualified Control.Monad.Trans.State as State
48 import qualified Data.ByteString.Lazy.Char8 as B
49 import Data.Aeson
50 import Data.Aeson.Encode.Pretty
51 import Data.Conduit as Conduit
52 import Data.Conduit.Parser as Parser
53 import Data.Default
54 import Data.Foldable
55 import Data.List
56 import qualified Data.Map as Map
57 import qualified Data.Text as T
58 import qualified Data.Text.IO as T
59 import qualified Data.HashMap.Strict as HashMap
60 import Data.Maybe
61 import Data.Function
62 import Language.Haskell.LSP.Messages
63 import Language.Haskell.LSP.Types.Capabilities
64 import Language.Haskell.LSP.Types
65 import Language.Haskell.LSP.Types.Lens
66 import qualified Language.Haskell.LSP.Types.Lens as LSP
67 import Language.Haskell.LSP.VFS
68 import Language.Haskell.LSP.Test.Compat
69 import Language.Haskell.LSP.Test.Decoding
70 import Language.Haskell.LSP.Test.Exceptions
71 import System.Console.ANSI
72 import System.Directory
73 import System.IO
74 import System.Process (waitForProcess, ProcessHandle())
75 import System.Timeout
76
77 -- | A session representing one instance of launching and connecting to a server.
78 --
79 -- You can send and receive messages to the server within 'Session' via
80 -- 'Language.Haskell.LSP.Test.message',
81 -- 'Language.Haskell.LSP.Test.sendRequest' and
82 -- 'Language.Haskell.LSP.Test.sendNotification'.
83
84 newtype Session a = Session (ConduitParser FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) a)
85   deriving (Functor, Applicative, Monad, MonadIO, Alternative)
86
87 #if __GLASGOW_HASKELL__ >= 806
88 instance MonadFail Session where
89   fail s = do
90     lastMsg <- fromJust . lastReceivedMessage <$> get
91     liftIO $ throw (UnexpectedMessage s lastMsg)
92 #endif
93
94 -- | Stuff you can configure for a 'Session'.
95 data SessionConfig = SessionConfig
96   { messageTimeout :: Int  -- ^ Maximum time to wait for a message in seconds, defaults to 60.
97   , logStdErr      :: Bool
98   -- ^ Redirect the server's stderr to this stdout, defaults to False.
99   -- Can be overriden with @LSP_TEST_LOG_STDERR@.
100   , logMessages    :: Bool
101   -- ^ Trace the messages sent and received to stdout, defaults to False.
102   -- Can be overriden with the environment variable @LSP_TEST_LOG_MESSAGES@.
103   , logColor       :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
104   , lspConfig      :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
105   , ignoreLogNotifications :: Bool
106   -- ^ Whether or not to ignore 'Language.Haskell.LSP.Types.ShowMessageNotification' and
107   -- 'Language.Haskell.LSP.Types.LogMessageNotification', defaults to False.
108   --
109   -- @since 0.9.0.0
110   }
111
112 -- | The configuration used in 'Language.Haskell.LSP.Test.runSession'.
113 defaultConfig :: SessionConfig
114 defaultConfig = SessionConfig 60 False False True Nothing False
115
116 instance Default SessionConfig where
117   def = defaultConfig
118
119 data SessionMessage = ServerMessage FromServerMessage
120                     | TimeoutMessage Int
121   deriving Show
122
123 data SessionContext = SessionContext
124   {
125     serverIn :: Handle
126   , rootDir :: FilePath
127   , messageChan :: Chan SessionMessage -- ^ Where all messages come through
128   -- Keep curTimeoutId in SessionContext, as its tied to messageChan
129   , curTimeoutId :: MVar Int -- ^ The current timeout we are waiting on
130   , requestMap :: MVar RequestMap
131   , initRsp :: MVar InitializeResponse
132   , config :: SessionConfig
133   , sessionCapabilities :: ClientCapabilities
134   }
135
136 class Monad m => HasReader r m where
137   ask :: m r
138   asks :: (r -> b) -> m b
139   asks f = f <$> ask
140
141 instance HasReader SessionContext Session where
142   ask  = Session (lift $ lift Reader.ask)
143
144 instance Monad m => HasReader r (ConduitM a b (StateT s (ReaderT r m))) where
145   ask = lift $ lift Reader.ask
146
147 getCurTimeoutId :: (HasReader SessionContext m, MonadIO m) => m Int
148 getCurTimeoutId = asks curTimeoutId >>= liftIO . readMVar
149
150 -- Pass this the timeoutid you *were* waiting on
151 bumpTimeoutId :: (HasReader SessionContext m, MonadIO m) => Int -> m ()
152 bumpTimeoutId prev = do
153   v <- asks curTimeoutId
154   -- when updating the curtimeoutid, account for the fact that something else
155   -- might have bumped the timeoutid in the meantime
156   liftIO $ modifyMVar_ v (\x -> pure (max x (prev + 1)))
157
158 data SessionState = SessionState
159   {
160     curReqId :: LspId
161   , vfs :: VFS
162   , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
163   , overridingTimeout :: Bool
164   -- ^ The last received message from the server.
165   -- Used for providing exception information
166   , lastReceivedMessage :: Maybe FromServerMessage
167   , curDynCaps :: Map.Map T.Text Registration
168   -- ^ The capabilities that the server has dynamically registered with us so
169   -- far
170   }
171
172 class Monad m => HasState s m where
173   get :: m s
174
175   put :: s -> m ()
176
177   modify :: (s -> s) -> m ()
178   modify f = get >>= put . f
179
180   modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
181   modifyM f = get >>= f >>= put
182
183 instance HasState SessionState Session where
184   get = Session (lift State.get)
185   put = Session . lift . State.put
186
187 instance Monad m => HasState s (StateT s m) where
188   get = State.get
189   put = State.put
190
191 instance (Monad m, (HasState s m)) => HasState s (ConduitM a b m)
192  where
193   get = lift get
194   put = lift . put
195
196 instance (Monad m, (HasState s m)) => HasState s (ConduitParser a m)
197  where
198   get = lift get
199   put = lift . put
200
201 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
202 runSession context state (Session session) = runReaderT (runStateT conduit state) context
203   where
204     conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
205
206     handler (Unexpected "ConduitParser.empty") = do
207       lastMsg <- fromJust . lastReceivedMessage <$> get
208       name <- getParserName
209       liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
210
211     handler e = throw e
212
213     chanSource = do
214       msg <- liftIO $ readChan (messageChan context)
215       unless (ignoreLogNotifications (config context) && isLogNotification msg) $
216         yield msg
217       chanSource
218
219     isLogNotification (ServerMessage (NotShowMessage _)) = True
220     isLogNotification (ServerMessage (NotLogMessage _)) = True
221     isLogNotification _ = False
222
223     watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
224     watchdog = Conduit.awaitForever $ \msg -> do
225       curId <- getCurTimeoutId
226       case msg of
227         ServerMessage sMsg -> yield sMsg
228         TimeoutMessage tId -> when (curId == tId) $ lastReceivedMessage <$> get >>= throw . Timeout
229
230 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
231 -- It also does not automatically send initialize and exit messages.
232 runSessionWithHandles :: Handle -- ^ Server in
233                       -> Handle -- ^ Server out
234                       -> ProcessHandle -- ^ Server process
235                       -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
236                       -> SessionConfig
237                       -> ClientCapabilities
238                       -> FilePath -- ^ Root directory
239                       -> Session () -- ^ To exit the Server properly
240                       -> Session a
241                       -> IO a
242 runSessionWithHandles serverIn serverOut serverProc serverHandler config caps rootDir exitServer session = do
243   absRootDir <- canonicalizePath rootDir
244
245   hSetBuffering serverIn  NoBuffering
246   hSetBuffering serverOut NoBuffering
247   -- This is required to make sure that we don’t get any
248   -- newline conversion or weird encoding issues.
249   hSetBinaryMode serverIn True
250   hSetBinaryMode serverOut True
251
252   reqMap <- newMVar newRequestMap
253   messageChan <- newChan
254   timeoutIdVar <- newMVar 0
255   initRsp <- newEmptyMVar
256
257   mainThreadId <- myThreadId
258
259   let context = SessionContext serverIn absRootDir messageChan timeoutIdVar reqMap initRsp config caps
260       initState vfs = SessionState (IdInt 0) vfs mempty False Nothing mempty
261       runSession' ses = initVFS $ \vfs -> runSession context (initState vfs) ses
262
263       errorHandler = throwTo mainThreadId :: SessionException -> IO ()
264       serverListenerLauncher =
265         forkIO $ catch (serverHandler serverOut context) errorHandler
266       server = (Just serverIn, Just serverOut, Nothing, serverProc)
267       msgTimeoutMs = messageTimeout config * 10^6
268       serverAndListenerFinalizer tid = do
269         finally (timeout msgTimeoutMs (runSession' exitServer)) $ do
270           -- Make sure to kill the listener first, before closing
271           -- handles etc via cleanupProcess
272           killThread tid
273           -- Give the server some time to exit cleanly
274           timeout msgTimeoutMs (waitForProcess serverProc)
275           cleanupProcess server
276
277   (result, _) <- bracket serverListenerLauncher
278                          serverAndListenerFinalizer
279                          (const $ runSession' session)
280   return result
281
282 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
283 updateStateC = awaitForever $ \msg -> do
284   updateState msg
285   yield msg
286
287 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m)
288             => FromServerMessage -> m ()
289
290 -- Keep track of dynamic capability registration
291 updateState (ReqRegisterCapability req) = do
292   let List newRegs = (\r -> (r ^. LSP.id, r)) <$> req ^. params . registrations
293   modify $ \s ->
294     s { curDynCaps = Map.union (Map.fromList newRegs) (curDynCaps s) }
295
296 updateState (ReqUnregisterCapability req) = do
297   let List unRegs = (^. LSP.id) <$> req ^. params . unregistrations
298   modify $ \s ->
299     let newCurDynCaps = foldr' Map.delete (curDynCaps s) unRegs
300     in s { curDynCaps = newCurDynCaps }
301
302 updateState (NotPublishDiagnostics n) = do
303   let List diags = n ^. params . diagnostics
304       doc = n ^. params . uri
305   modify $ \s ->
306     let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s)
307       in s { curDiagnostics = newDiags }
308
309 updateState (ReqApplyWorkspaceEdit r) = do
310
311   -- First, prefer the versioned documentChanges field
312   allChangeParams <- case r ^. params . edit . documentChanges of
313     Just (List cs) -> do
314       mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
315       return $ map getParams cs
316     -- Then fall back to the changes field
317     Nothing -> case r ^. params . edit . changes of
318       Just cs -> do
319         mapM_ checkIfNeedsOpened (HashMap.keys cs)
320         concat <$> mapM (uncurry getChangeParams) (HashMap.toList cs)
321       Nothing ->
322         error "WorkspaceEdit contains neither documentChanges nor changes!"
323
324   modifyM $ \s -> do
325     newVFS <- liftIO $ changeFromServerVFS (vfs s) r
326     return $ s { vfs = newVFS }
327
328   let groupedParams = groupBy (\a b -> a ^. textDocument == b ^. textDocument) allChangeParams
329       mergedParams = map mergeParams groupedParams
330
331   -- TODO: Don't do this when replaying a session
332   forM_ mergedParams (sendMessage . NotificationMessage "2.0" TextDocumentDidChange)
333
334   -- Update VFS to new document versions
335   let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
336       latestVersions = map ((^. textDocument) . last) sortedVersions
337       bumpedVersions = map (version . _Just +~ 1) latestVersions
338
339   forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
340     modify $ \s ->
341       let oldVFS = vfs s
342           update (VirtualFile oldV file_ver t) = VirtualFile (fromMaybe oldV v) (file_ver + 1) t
343           newVFS = updateVFS (Map.adjust update (toNormalizedUri uri)) oldVFS
344       in s { vfs = newVFS }
345
346   where checkIfNeedsOpened uri = do
347           oldVFS <- vfs <$> get
348           ctx <- ask
349
350           -- if its not open, open it
351           unless (toNormalizedUri uri `Map.member` vfsMap oldVFS) $ do
352             let fp = fromJust $ uriToFilePath uri
353             contents <- liftIO $ T.readFile fp
354             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
355                 msg = NotificationMessage "2.0" TextDocumentDidOpen (DidOpenTextDocumentParams item)
356             liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
357
358             modifyM $ \s -> do
359               let (newVFS,_) = openVFS (vfs s) msg
360               return $ s { vfs = newVFS }
361
362         getParams (TextDocumentEdit docId (List edits)) =
363           let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
364             in DidChangeTextDocumentParams docId (List changeEvents)
365
366         -- For a uri returns an infinite list of versions [n,n+1,n+2,...]
367         -- where n is the current version
368         textDocumentVersions uri = do
369           m <- vfsMap . vfs <$> get
370           let curVer = fromMaybe 0 $
371                 _lsp_version <$> m Map.!? (toNormalizedUri uri)
372           pure $ map (VersionedTextDocumentIdentifier uri . Just) [curVer + 1..]
373
374         textDocumentEdits uri edits = do
375           vers <- textDocumentVersions uri
376           pure $ map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip vers edits
377
378         getChangeParams uri (List edits) =
379           map <$> pure getParams <*> textDocumentEdits uri (reverse edits)
380
381         mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
382         mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
383                               in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
384 updateState _ = return ()
385
386 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
387 sendMessage msg = do
388   h <- serverIn <$> ask
389   logMsg LogClient msg
390   liftIO $ B.hPut h (addHeader $ encode msg)
391
392 -- | Execute a block f that will throw a 'Language.Haskell.LSP.Test.Exception.Timeout' exception
393 -- after duration seconds. This will override the global timeout
394 -- for waiting for messages to arrive defined in 'SessionConfig'.
395 withTimeout :: Int -> Session a -> Session a
396 withTimeout duration f = do
397   chan <- asks messageChan
398   timeoutId <- getCurTimeoutId
399   modify $ \s -> s { overridingTimeout = True }
400   liftIO $ forkIO $ do
401     threadDelay (duration * 1000000)
402     writeChan chan (TimeoutMessage timeoutId)
403   res <- f
404   bumpTimeoutId timeoutId
405   modify $ \s -> s { overridingTimeout = False }
406   return res
407
408 data LogMsgType = LogServer | LogClient
409   deriving Eq
410
411 -- | Logs the message if the config specified it
412 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
413        => LogMsgType -> a -> m ()
414 logMsg t msg = do
415   shouldLog <- asks $ logMessages . config
416   shouldColor <- asks $ logColor . config
417   liftIO $ when shouldLog $ do
418     when shouldColor $ setSGR [SetColor Foreground Dull color]
419     putStrLn $ arrow ++ showPretty msg
420     when shouldColor $ setSGR [Reset]
421
422   where arrow
423           | t == LogServer  = "<-- "
424           | otherwise       = "--> "
425         color
426           | t == LogServer  = Magenta
427           | otherwise       = Cyan
428
429         showPretty = B.unpack . encodePretty