Replace cleanupServer with functions avalilable in ghc <= 8.4
[lsp-test.git] / src / Language / Haskell / LSP / Test / Session.hs
1 {-# LANGUAGE CPP               #-}
2 {-# LANGUAGE OverloadedStrings #-}
3 {-# LANGUAGE FlexibleInstances #-}
4 {-# LANGUAGE MultiParamTypeClasses #-}
5 {-# LANGUAGE FlexibleContexts #-}
6 {-# LANGUAGE RankNTypes #-}
7
8 module Language.Haskell.LSP.Test.Session
9   ( Session
10   , SessionConfig(..)
11   , defaultConfig
12   , SessionMessage(..)
13   , SessionContext(..)
14   , SessionState(..)
15   , runSessionWithHandles
16   , get
17   , put
18   , modify
19   , modifyM
20   , ask
21   , asks
22   , sendMessage
23   , updateState
24   , withTimeout
25   , logMsg
26   , LogMsgType(..)
27   )
28
29 where
30
31 import Control.Concurrent hiding (yield)
32 import Control.Exception
33 import Control.Lens hiding (List)
34 import Control.Monad
35 import Control.Monad.IO.Class
36 import Control.Monad.Except
37 #if __GLASGOW_HASKELL__ >= 806
38 import Control.Monad.Fail
39 #endif
40 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
41 import qualified Control.Monad.Trans.Reader as Reader (ask)
42 import Control.Monad.Trans.State (StateT, runStateT)
43 import qualified Control.Monad.Trans.State as State (get, put)
44 import qualified Data.ByteString.Lazy.Char8 as B
45 import Data.Aeson
46 import Data.Aeson.Encode.Pretty
47 import Data.Conduit as Conduit
48 import Data.Conduit.Parser as Parser
49 import Data.Default
50 import Data.Foldable
51 import Data.List
52 import qualified Data.Map as Map
53 import qualified Data.Text as T
54 import qualified Data.Text.IO as T
55 import qualified Data.HashMap.Strict as HashMap
56 import Data.Maybe
57 import Data.Function
58 import Language.Haskell.LSP.Messages
59 import Language.Haskell.LSP.Types.Capabilities
60 import Language.Haskell.LSP.Types
61 import Language.Haskell.LSP.Types.Lens hiding (error)
62 import Language.Haskell.LSP.VFS
63 import Language.Haskell.LSP.Test.Decoding
64 import Language.Haskell.LSP.Test.Exceptions
65 import System.Console.ANSI
66 import System.Directory
67 import System.IO
68 import System.Process
69 import System.Timeout
70
71 -- | A session representing one instance of launching and connecting to a server.
72 --
73 -- You can send and receive messages to the server within 'Session' via
74 -- 'Language.Haskell.LSP.Test.message',
75 -- 'Language.Haskell.LSP.Test.sendRequest' and
76 -- 'Language.Haskell.LSP.Test.sendNotification'.
77
78 type Session = ParserStateReader FromServerMessage SessionState SessionContext IO
79
80 #if __GLASGOW_HASKELL__ >= 806
81 instance MonadFail Session where
82   fail s = do
83     lastMsg <- fromJust . lastReceivedMessage <$> get
84     liftIO $ throw (UnexpectedMessage s lastMsg)
85 #endif
86
87 -- | Stuff you can configure for a 'Session'.
88 data SessionConfig = SessionConfig
89   { messageTimeout :: Int  -- ^ Maximum time to wait for a message in seconds, defaults to 60.
90   , logStdErr      :: Bool -- ^ Redirect the server's stderr to this stdout, defaults to False.
91   , logMessages    :: Bool -- ^ Trace the messages sent and received to stdout, defaults to False.
92   , logColor       :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
93   , lspConfig      :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
94   }
95
96 -- | The configuration used in 'Language.Haskell.LSP.Test.runSession'.
97 defaultConfig :: SessionConfig
98 defaultConfig = SessionConfig 60 False False True Nothing
99
100 instance Default SessionConfig where
101   def = defaultConfig
102
103 data SessionMessage = ServerMessage FromServerMessage
104                     | TimeoutMessage Int
105   deriving Show
106
107 data SessionContext = SessionContext
108   {
109     serverIn :: Handle
110   , rootDir :: FilePath
111   , messageChan :: Chan SessionMessage
112   , requestMap :: MVar RequestMap
113   , initRsp :: MVar InitializeResponse
114   , config :: SessionConfig
115   , sessionCapabilities :: ClientCapabilities
116   }
117
118 class Monad m => HasReader r m where
119   ask :: m r
120   asks :: (r -> b) -> m b
121   asks f = f <$> ask
122
123 instance Monad m => HasReader r (ParserStateReader a s r m) where
124   ask = lift $ lift Reader.ask
125
126 instance Monad m => HasReader SessionContext (ConduitM a b (StateT s (ReaderT SessionContext m))) where
127   ask = lift $ lift Reader.ask
128
129 data SessionState = SessionState
130   {
131     curReqId :: LspId
132   , vfs :: VFS
133   , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
134   , curTimeoutId :: Int
135   , overridingTimeout :: Bool
136   -- ^ The last received message from the server.
137   -- Used for providing exception information
138   , lastReceivedMessage :: Maybe FromServerMessage
139   }
140
141 class Monad m => HasState s m where
142   get :: m s
143
144   put :: s -> m ()
145
146   modify :: (s -> s) -> m ()
147   modify f = get >>= put . f
148
149   modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
150   modifyM f = get >>= f >>= put
151
152 instance Monad m => HasState s (ParserStateReader a s r m) where
153   get = lift State.get
154   put = lift . State.put
155
156 instance Monad m => HasState SessionState (ConduitM a b (StateT SessionState m))
157  where
158   get = lift State.get
159   put = lift . State.put
160
161 type ParserStateReader a s r m = ConduitParser a (StateT s (ReaderT r m))
162
163 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
164 runSession context state session = runReaderT (runStateT conduit state) context
165   where
166     conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
167
168     handler (Unexpected "ConduitParser.empty") = do
169       lastMsg <- fromJust . lastReceivedMessage <$> get
170       name <- getParserName
171       liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
172
173     handler e = throw e
174
175     chanSource = do
176       msg <- liftIO $ readChan (messageChan context)
177       yield msg
178       chanSource
179
180     watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
181     watchdog = Conduit.awaitForever $ \msg -> do
182       curId <- curTimeoutId <$> get
183       case msg of
184         ServerMessage sMsg -> yield sMsg
185         TimeoutMessage tId -> when (curId == tId) $ throw Timeout
186
187 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
188 -- It also does not automatically send initialize and exit messages.
189 runSessionWithHandles :: Handle -- ^ Server in
190                       -> Handle -- ^ Server out
191                       -> ProcessHandle -- ^ Server process
192                       -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
193                       -> SessionConfig
194                       -> ClientCapabilities
195                       -> FilePath -- ^ Root directory
196                       -> Session () -- ^ To exit the Server properly
197                       -> Session a
198                       -> IO a
199 runSessionWithHandles serverIn serverOut serverProc serverHandler config caps rootDir exitServer session = do
200   
201   absRootDir <- canonicalizePath rootDir
202
203   hSetBuffering serverIn  NoBuffering
204   hSetBuffering serverOut NoBuffering
205   -- This is required to make sure that we don’t get any
206   -- newline conversion or weird encoding issues.
207   hSetBinaryMode serverIn True
208   hSetBinaryMode serverOut True
209
210   reqMap <- newMVar newRequestMap
211   messageChan <- newChan
212   initRsp <- newEmptyMVar
213
214   mainThreadId <- myThreadId
215
216   let context = SessionContext serverIn absRootDir messageChan reqMap initRsp config caps
217       initState = SessionState (IdInt 0) mempty mempty 0 False Nothing
218       runSession' = runSession context initState
219       
220       errorHandler = throwTo mainThreadId :: SessionException -> IO()
221       serverLauncher = forkIO $ catch (serverHandler serverOut context) errorHandler
222       server = (Just serverIn, Just serverOut, Nothing, serverProc)
223       serverFinalizer tid = finally (timeout (messageTimeout config * 1000000)
224                                              (runSession' exitServer))
225                                     (terminateProcess serverProc
226                                       >> hClose serverOut
227                                       >> killThread tid)
228       
229   (result, _) <- bracket serverLauncher serverFinalizer (const $ runSession' session)
230   return result
231
232 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
233 updateStateC = awaitForever $ \msg -> do
234   updateState msg
235   yield msg
236
237 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m) => FromServerMessage -> m ()
238 updateState (NotPublishDiagnostics n) = do
239   let List diags = n ^. params . diagnostics
240       doc = n ^. params . uri
241   modify (\s ->
242     let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s)
243       in s { curDiagnostics = newDiags })
244
245 updateState (ReqApplyWorkspaceEdit r) = do
246
247   allChangeParams <- case r ^. params . edit . documentChanges of
248     Just (List cs) -> do
249       mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
250       return $ map getParams cs
251     Nothing -> case r ^. params . edit . changes of
252       Just cs -> do
253         mapM_ checkIfNeedsOpened (HashMap.keys cs)
254         return $ concatMap (uncurry getChangeParams) (HashMap.toList cs)
255       Nothing -> error "No changes!"
256
257   modifyM $ \s -> do
258     newVFS <- liftIO $ changeFromServerVFS (vfs s) r
259     return $ s { vfs = newVFS }
260
261   let groupedParams = groupBy (\a b -> a ^. textDocument == b ^. textDocument) allChangeParams
262       mergedParams = map mergeParams groupedParams
263
264   -- TODO: Don't do this when replaying a session
265   forM_ mergedParams (sendMessage . NotificationMessage "2.0" TextDocumentDidChange)
266
267   -- Update VFS to new document versions
268   let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
269       latestVersions = map ((^. textDocument) . last) sortedVersions
270       bumpedVersions = map (version . _Just +~ 1) latestVersions
271
272   forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
273     modify $ \s ->
274       let oldVFS = vfs s
275           update (VirtualFile oldV t mf) = VirtualFile (fromMaybe oldV v) t mf
276           newVFS = Map.adjust update (toNormalizedUri uri) oldVFS
277       in s { vfs = newVFS }
278
279   where checkIfNeedsOpened uri = do
280           oldVFS <- vfs <$> get
281           ctx <- ask
282
283           -- if its not open, open it
284           unless (toNormalizedUri uri `Map.member` oldVFS) $ do
285             let fp = fromJust $ uriToFilePath uri
286             contents <- liftIO $ T.readFile fp
287             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
288                 msg = NotificationMessage "2.0" TextDocumentDidOpen (DidOpenTextDocumentParams item)
289             liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
290
291             modifyM $ \s -> do
292               newVFS <- liftIO $ openVFS (vfs s) msg
293               return $ s { vfs = newVFS }
294
295         getParams (TextDocumentEdit docId (List edits)) =
296           let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
297             in DidChangeTextDocumentParams docId (List changeEvents)
298
299         textDocumentVersions uri = map (VersionedTextDocumentIdentifier uri . Just) [0..]
300
301         textDocumentEdits uri edits = map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip (textDocumentVersions uri) edits
302
303         getChangeParams uri (List edits) = map getParams (textDocumentEdits uri (reverse edits))
304
305         mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
306         mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
307                               in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
308 updateState _ = return ()
309
310 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
311 sendMessage msg = do
312   h <- serverIn <$> ask
313   logMsg LogClient msg
314   liftIO $ B.hPut h (addHeader $ encode msg)
315
316 -- | Execute a block f that will throw a 'Timeout' exception
317 -- after duration seconds. This will override the global timeout
318 -- for waiting for messages to arrive defined in 'SessionConfig'.
319 withTimeout :: Int -> Session a -> Session a
320 withTimeout duration f = do
321   chan <- asks messageChan
322   timeoutId <- curTimeoutId <$> get
323   modify $ \s -> s { overridingTimeout = True }
324   liftIO $ forkIO $ do
325     threadDelay (duration * 1000000)
326     writeChan chan (TimeoutMessage timeoutId)
327   res <- f
328   modify $ \s -> s { curTimeoutId = timeoutId + 1,
329                      overridingTimeout = False
330                    }
331   return res
332
333 data LogMsgType = LogServer | LogClient
334   deriving Eq
335
336 -- | Logs the message if the config specified it
337 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
338        => LogMsgType -> a -> m ()
339 logMsg t msg = do
340   shouldLog <- asks $ logMessages . config
341   shouldColor <- asks $ logColor . config
342   liftIO $ when shouldLog $ do
343     when shouldColor $ setSGR [SetColor Foreground Dull color]
344     putStrLn $ arrow ++ showPretty msg
345     when shouldColor $ setSGR [Reset]
346
347   where arrow
348           | t == LogServer  = "<-- "
349           | otherwise       = "--> "
350         color
351           | t == LogServer  = Magenta
352           | otherwise       = Cyan
353
354         showPretty = B.unpack . encodePretty
355