Remove no longer needed dependency
[lsp-test.git] / src / Language / Haskell / LSP / Test / Session.hs
1 {-# LANGUAGE CPP               #-}
2 {-# LANGUAGE OverloadedStrings #-}
3 {-# LANGUAGE GeneralizedNewtypeDeriving #-}
4 {-# LANGUAGE FlexibleInstances #-}
5 {-# LANGUAGE MultiParamTypeClasses #-}
6 {-# LANGUAGE FlexibleContexts #-}
7 {-# LANGUAGE RankNTypes #-}
8
9 module Language.Haskell.LSP.Test.Session
10   ( Session(..)
11   , SessionConfig(..)
12   , defaultConfig
13   , SessionMessage(..)
14   , SessionContext(..)
15   , SessionState(..)
16   , runSessionWithHandles
17   , get
18   , put
19   , modify
20   , modifyM
21   , ask
22   , asks
23   , sendMessage
24   , updateState
25   , withTimeout
26   , logMsg
27   , LogMsgType(..)
28   )
29
30 where
31
32 import Control.Applicative
33 import Control.Concurrent hiding (yield)
34 import Control.Exception
35 import Control.Lens hiding (List)
36 import Control.Monad
37 import Control.Monad.IO.Class
38 import Control.Monad.Except
39 #if __GLASGOW_HASKELL__ == 806
40 import Control.Monad.Fail
41 #endif
42 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
43 import qualified Control.Monad.Trans.Reader as Reader (ask)
44 import Control.Monad.Trans.State (StateT, runStateT)
45 import qualified Control.Monad.Trans.State as State
46 import qualified Data.ByteString.Lazy.Char8 as B
47 import Data.Aeson
48 import Data.Aeson.Encode.Pretty
49 import Data.Conduit as Conduit
50 import Data.Conduit.Parser as Parser
51 import Data.Default
52 import Data.Foldable
53 import Data.List
54 import qualified Data.Map as Map
55 import qualified Data.Text as T
56 import qualified Data.Text.IO as T
57 import qualified Data.HashMap.Strict as HashMap
58 import Data.Maybe
59 import Data.Function
60 import Language.Haskell.LSP.Messages
61 import Language.Haskell.LSP.Types.Capabilities
62 import Language.Haskell.LSP.Types
63 import Language.Haskell.LSP.Types.Lens hiding (error)
64 import Language.Haskell.LSP.VFS
65 import Language.Haskell.LSP.Test.Compat
66 import Language.Haskell.LSP.Test.Decoding
67 import Language.Haskell.LSP.Test.Exceptions
68 import System.Console.ANSI
69 import System.Directory
70 import System.IO
71 import System.Process (ProcessHandle())
72 import System.Timeout
73
74 -- | A session representing one instance of launching and connecting to a server.
75 --
76 -- You can send and receive messages to the server within 'Session' via
77 -- 'Language.Haskell.LSP.Test.message',
78 -- 'Language.Haskell.LSP.Test.sendRequest' and
79 -- 'Language.Haskell.LSP.Test.sendNotification'.
80
81 newtype Session a = Session (ConduitParser FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) a)
82   deriving (Functor, Applicative, Monad, MonadIO, Alternative)
83
84 #if __GLASGOW_HASKELL__ >= 806
85 instance MonadFail Session where
86   fail s = do
87     lastMsg <- fromJust . lastReceivedMessage <$> get
88     liftIO $ throw (UnexpectedMessage s lastMsg)
89 #endif
90
91 -- | Stuff you can configure for a 'Session'.
92 data SessionConfig = SessionConfig
93   { messageTimeout :: Int  -- ^ Maximum time to wait for a message in seconds, defaults to 60.
94   , logStdErr      :: Bool -- ^ Redirect the server's stderr to this stdout, defaults to False.
95   , logMessages    :: Bool -- ^ Trace the messages sent and received to stdout, defaults to False.
96   , logColor       :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
97   , lspConfig      :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
98   }
99
100 -- | The configuration used in 'Language.Haskell.LSP.Test.runSession'.
101 defaultConfig :: SessionConfig
102 defaultConfig = SessionConfig 60 False False True Nothing
103
104 instance Default SessionConfig where
105   def = defaultConfig
106
107 data SessionMessage = ServerMessage FromServerMessage
108                     | TimeoutMessage Int
109   deriving Show
110
111 data SessionContext = SessionContext
112   {
113     serverIn :: Handle
114   , rootDir :: FilePath
115   , messageChan :: Chan SessionMessage
116   , requestMap :: MVar RequestMap
117   , initRsp :: MVar InitializeResponse
118   , config :: SessionConfig
119   , sessionCapabilities :: ClientCapabilities
120   }
121
122 class Monad m => HasReader r m where
123   ask :: m r
124   asks :: (r -> b) -> m b
125   asks f = f <$> ask
126
127 instance HasReader SessionContext Session where
128   ask  = Session (lift $ lift Reader.ask)
129
130 instance Monad m => HasReader r (ConduitM a b (StateT s (ReaderT r m))) where
131   ask = lift $ lift Reader.ask
132
133 data SessionState = SessionState
134   {
135     curReqId :: LspId
136   , vfs :: VFS
137   , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
138   , curTimeoutId :: Int
139   , overridingTimeout :: Bool
140   -- ^ The last received message from the server.
141   -- Used for providing exception information
142   , lastReceivedMessage :: Maybe FromServerMessage
143   }
144
145 class Monad m => HasState s m where
146   get :: m s
147
148   put :: s -> m ()
149
150   modify :: (s -> s) -> m ()
151   modify f = get >>= put . f
152
153   modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
154   modifyM f = get >>= f >>= put
155
156 instance HasState SessionState Session where
157   get = Session (lift State.get)
158   put = Session . lift . State.put
159
160 instance Monad m => HasState s (ConduitM a b (StateT s m))
161  where
162   get = lift State.get
163   put = lift . State.put
164
165 instance Monad m => HasState s (ConduitParser a (StateT s m))
166  where
167   get = lift State.get
168   put = lift . State.put
169
170 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
171 runSession context state (Session session) = runReaderT (runStateT conduit state) context
172   where
173     conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
174
175     handler (Unexpected "ConduitParser.empty") = do
176       lastMsg <- fromJust . lastReceivedMessage <$> get
177       name <- getParserName
178       liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
179
180     handler e = throw e
181
182     chanSource = do
183       msg <- liftIO $ readChan (messageChan context)
184       yield msg
185       chanSource
186
187     watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
188     watchdog = Conduit.awaitForever $ \msg -> do
189       curId <- curTimeoutId <$> get
190       case msg of
191         ServerMessage sMsg -> yield sMsg
192         TimeoutMessage tId -> when (curId == tId) $ throw Timeout
193
194 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
195 -- It also does not automatically send initialize and exit messages.
196 runSessionWithHandles :: Handle -- ^ Server in
197                       -> Handle -- ^ Server out
198                       -> ProcessHandle -- ^ Server process
199                       -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
200                       -> SessionConfig
201                       -> ClientCapabilities
202                       -> FilePath -- ^ Root directory
203                       -> Session () -- ^ To exit the Server properly
204                       -> Session a
205                       -> IO a
206 runSessionWithHandles serverIn serverOut serverProc serverHandler config caps rootDir exitServer session = do
207   absRootDir <- canonicalizePath rootDir
208
209   hSetBuffering serverIn  NoBuffering
210   hSetBuffering serverOut NoBuffering
211   -- This is required to make sure that we don’t get any
212   -- newline conversion or weird encoding issues.
213   hSetBinaryMode serverIn True
214   hSetBinaryMode serverOut True
215
216   reqMap <- newMVar newRequestMap
217   messageChan <- newChan
218   initRsp <- newEmptyMVar
219
220   mainThreadId <- myThreadId
221
222   let context = SessionContext serverIn absRootDir messageChan reqMap initRsp config caps
223       initState vfs = SessionState (IdInt 0) vfs
224                                        mempty 0 False Nothing
225       runSession' ses = initVFS $ \vfs -> runSession context (initState vfs) ses
226
227       errorHandler = throwTo mainThreadId :: SessionException -> IO()
228       serverListenerLauncher =
229         forkIO $ catch (serverHandler serverOut context) errorHandler
230       server = (Just serverIn, Just serverOut, Nothing, serverProc)
231       serverAndListenerFinalizer tid =
232         finally (timeout (messageTimeout config * 1000000)
233                          (runSession' exitServer))
234                 (cleanupProcess server >> killThread tid)
235
236   (result, _) <- bracket serverListenerLauncher serverAndListenerFinalizer
237                          (const $ runSession' session)
238   return result
239
240 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
241 updateStateC = awaitForever $ \msg -> do
242   updateState msg
243   yield msg
244
245 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m)
246             => FromServerMessage -> m ()
247 updateState (NotPublishDiagnostics n) = do
248   let List diags = n ^. params . diagnostics
249       doc = n ^. params . uri
250   modify (\s ->
251     let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s)
252       in s { curDiagnostics = newDiags })
253
254 updateState (ReqApplyWorkspaceEdit r) = do
255
256   allChangeParams <- case r ^. params . edit . documentChanges of
257     Just (List cs) -> do
258       mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
259       return $ map getParams cs
260     Nothing -> case r ^. params . edit . changes of
261       Just cs -> do
262         mapM_ checkIfNeedsOpened (HashMap.keys cs)
263         return $ concatMap (uncurry getChangeParams) (HashMap.toList cs)
264       Nothing -> error "No changes!"
265
266   modifyM $ \s -> do
267     newVFS <- liftIO $ changeFromServerVFS (vfs s) r
268     return $ s { vfs = newVFS }
269
270   let groupedParams = groupBy (\a b -> a ^. textDocument == b ^. textDocument) allChangeParams
271       mergedParams = map mergeParams groupedParams
272
273   -- TODO: Don't do this when replaying a session
274   forM_ mergedParams (sendMessage . NotificationMessage "2.0" TextDocumentDidChange)
275
276   -- Update VFS to new document versions
277   let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
278       latestVersions = map ((^. textDocument) . last) sortedVersions
279       bumpedVersions = map (version . _Just +~ 1) latestVersions
280
281   forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
282     modify $ \s ->
283       let oldVFS = vfs s
284           update (VirtualFile oldV t) = VirtualFile (fromMaybe oldV v) t
285           newVFS = updateVFS (Map.adjust update (toNormalizedUri uri)) oldVFS
286       in s { vfs = newVFS }
287
288   where checkIfNeedsOpened uri = do
289           oldVFS <- vfs <$> get
290           ctx <- ask
291
292           -- if its not open, open it
293           unless (toNormalizedUri uri `Map.member` (vfsMap oldVFS)) $ do
294             let fp = fromJust $ uriToFilePath uri
295             contents <- liftIO $ T.readFile fp
296             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
297                 msg = NotificationMessage "2.0" TextDocumentDidOpen (DidOpenTextDocumentParams item)
298             liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
299
300             modifyM $ \s -> do
301               let (newVFS,_) = openVFS (vfs s) msg
302               return $ s { vfs = newVFS }
303
304         getParams (TextDocumentEdit docId (List edits)) =
305           let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
306             in DidChangeTextDocumentParams docId (List changeEvents)
307
308         textDocumentVersions uri = map (VersionedTextDocumentIdentifier uri . Just) [0..]
309
310         textDocumentEdits uri edits = map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip (textDocumentVersions uri) edits
311
312         getChangeParams uri (List edits) = map getParams (textDocumentEdits uri (reverse edits))
313
314         mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
315         mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
316                               in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
317 updateState _ = return ()
318
319 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
320 sendMessage msg = do
321   h <- serverIn <$> ask
322   logMsg LogClient msg
323   liftIO $ B.hPut h (addHeader $ encode msg)
324
325 -- | Execute a block f that will throw a 'Timeout' exception
326 -- after duration seconds. This will override the global timeout
327 -- for waiting for messages to arrive defined in 'SessionConfig'.
328 withTimeout :: Int -> Session a -> Session a
329 withTimeout duration f = do
330   chan <- asks messageChan
331   timeoutId <- curTimeoutId <$> get
332   modify $ \s -> s { overridingTimeout = True }
333   liftIO $ forkIO $ do
334     threadDelay (duration * 1000000)
335     writeChan chan (TimeoutMessage timeoutId)
336   res <- f
337   modify $ \s -> s { curTimeoutId = timeoutId + 1,
338                      overridingTimeout = False
339                    }
340   return res
341
342 data LogMsgType = LogServer | LogClient
343   deriving Eq
344
345 -- | Logs the message if the config specified it
346 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
347        => LogMsgType -> a -> m ()
348 logMsg t msg = do
349   shouldLog <- asks $ logMessages . config
350   shouldColor <- asks $ logColor . config
351   liftIO $ when shouldLog $ do
352     when shouldColor $ setSGR [SetColor Foreground Dull color]
353     putStrLn $ arrow ++ showPretty msg
354     when shouldColor $ setSGR [Reset]
355
356   where arrow
357           | t == LogServer  = "<-- "
358           | otherwise       = "--> "
359         color
360           | t == LogServer  = Magenta
361           | otherwise       = Cyan
362
363         showPretty = B.unpack . encodePretty
364