Export satisfy
[lsp-test.git] / src / Language / Haskell / LSP / Test / Session.hs
1 {-# LANGUAGE CPP               #-}
2 {-# LANGUAGE OverloadedStrings #-}
3 {-# LANGUAGE FlexibleInstances #-}
4 {-# LANGUAGE MultiParamTypeClasses #-}
5 {-# LANGUAGE FlexibleContexts #-}
6 {-# LANGUAGE RankNTypes #-}
7
8 module Language.Haskell.LSP.Test.Session
9   ( Session
10   , SessionConfig(..)
11   , defaultConfig
12   , SessionMessage(..)
13   , SessionContext(..)
14   , SessionState(..)
15   , runSessionWithHandles
16   , get
17   , put
18   , modify
19   , modifyM
20   , ask
21   , asks
22   , sendMessage
23   , updateState
24   , withTimeout
25   , logMsg
26   , LogMsgType(..)
27   )
28
29 where
30
31 import Control.Concurrent hiding (yield)
32 import Control.Exception
33 import Control.Lens hiding (List)
34 import Control.Monad
35 import Control.Monad.IO.Class
36 import Control.Monad.Except
37 #if __GLASGOW_HASKELL__ >= 806
38 import Control.Monad.Fail
39 #endif
40 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
41 import qualified Control.Monad.Trans.Reader as Reader (ask)
42 import Control.Monad.Trans.State (StateT, runStateT)
43 import qualified Control.Monad.Trans.State as State (get, put)
44 import qualified Data.ByteString.Lazy.Char8 as B
45 import Data.Aeson
46 import Data.Aeson.Encode.Pretty
47 import Data.Conduit as Conduit
48 import Data.Conduit.Parser as Parser
49 import Data.Default
50 import Data.Foldable
51 import Data.List
52 import qualified Data.Map as Map
53 import qualified Data.Text as T
54 import qualified Data.Text.IO as T
55 import qualified Data.HashMap.Strict as HashMap
56 import Data.Maybe
57 import Data.Function
58 import Language.Haskell.LSP.Messages
59 import Language.Haskell.LSP.Types.Capabilities
60 import Language.Haskell.LSP.Types
61 import Language.Haskell.LSP.Types.Lens hiding (error)
62 import Language.Haskell.LSP.VFS
63 import Language.Haskell.LSP.Test.Decoding
64 import Language.Haskell.LSP.Test.Exceptions
65 import System.Console.ANSI
66 import System.Directory
67 import System.IO
68
69 -- | A session representing one instance of launching and connecting to a server.
70 --
71 -- You can send and receive messages to the server within 'Session' via
72 -- 'Language.Haskell.LSP.Test.message',
73 -- 'Language.Haskell.LSP.Test.sendRequest' and
74 -- 'Language.Haskell.LSP.Test.sendNotification'.
75
76 type Session = ParserStateReader FromServerMessage SessionState SessionContext IO
77
78 #if __GLASGOW_HASKELL__ >= 806
79 instance MonadFail Session where
80   fail s = do
81     lastMsg <- fromJust . lastReceivedMessage <$> get
82     liftIO $ throw (UnexpectedMessage s lastMsg)
83 #endif
84
85 -- | Stuff you can configure for a 'Session'.
86 data SessionConfig = SessionConfig
87   { messageTimeout :: Int  -- ^ Maximum time to wait for a message in seconds, defaults to 60.
88   , logStdErr      :: Bool -- ^ Redirect the server's stderr to this stdout, defaults to False.
89   , logMessages    :: Bool -- ^ Trace the messages sent and received to stdout, defaults to False.
90   , logColor       :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
91   , lspConfig      :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
92   }
93
94 -- | The configuration used in 'Language.Haskell.LSP.Test.runSession'.
95 defaultConfig :: SessionConfig
96 defaultConfig = SessionConfig 60 False False True Nothing
97
98 instance Default SessionConfig where
99   def = defaultConfig
100
101 data SessionMessage = ServerMessage FromServerMessage
102                     | TimeoutMessage Int
103   deriving Show
104
105 data SessionContext = SessionContext
106   {
107     serverIn :: Handle
108   , rootDir :: FilePath
109   , messageChan :: Chan SessionMessage
110   , requestMap :: MVar RequestMap
111   , initRsp :: MVar InitializeResponse
112   , config :: SessionConfig
113   , sessionCapabilities :: ClientCapabilities
114   }
115
116 class Monad m => HasReader r m where
117   ask :: m r
118   asks :: (r -> b) -> m b
119   asks f = f <$> ask
120
121 instance Monad m => HasReader r (ParserStateReader a s r m) where
122   ask = lift $ lift Reader.ask
123
124 instance Monad m => HasReader SessionContext (ConduitM a b (StateT s (ReaderT SessionContext m))) where
125   ask = lift $ lift Reader.ask
126
127 data SessionState = SessionState
128   {
129     curReqId :: LspId
130   , vfs :: VFS
131   , curDiagnostics :: Map.Map Uri [Diagnostic]
132   , curTimeoutId :: Int
133   , overridingTimeout :: Bool
134   -- ^ The last received message from the server.
135   -- Used for providing exception information
136   , lastReceivedMessage :: Maybe FromServerMessage
137   }
138
139 class Monad m => HasState s m where
140   get :: m s
141
142   put :: s -> m ()
143
144   modify :: (s -> s) -> m ()
145   modify f = get >>= put . f
146
147   modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
148   modifyM f = get >>= f >>= put
149
150 instance Monad m => HasState s (ParserStateReader a s r m) where
151   get = lift State.get
152   put = lift . State.put
153
154 instance Monad m => HasState SessionState (ConduitM a b (StateT SessionState m))
155  where
156   get = lift State.get
157   put = lift . State.put
158
159 type ParserStateReader a s r m = ConduitParser a (StateT s (ReaderT r m))
160
161 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
162 runSession context state session = runReaderT (runStateT conduit state) context
163   where
164     conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
165
166     handler (Unexpected "ConduitParser.empty") = do
167       lastMsg <- fromJust . lastReceivedMessage <$> get
168       name <- getParserName
169       liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
170
171     handler e = throw e
172
173     chanSource = do
174       msg <- liftIO $ readChan (messageChan context)
175       yield msg
176       chanSource
177
178     watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
179     watchdog = Conduit.awaitForever $ \msg -> do
180       curId <- curTimeoutId <$> get
181       case msg of
182         ServerMessage sMsg -> yield sMsg
183         TimeoutMessage tId -> when (curId == tId) $ throw Timeout
184
185 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
186 -- It also does not automatically send initialize and exit messages.
187 runSessionWithHandles :: Handle -- ^ Server in
188                       -> Handle -- ^ Server out
189                       -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
190                       -> SessionConfig
191                       -> ClientCapabilities
192                       -> FilePath -- ^ Root directory
193                       -> Session a
194                       -> IO a
195 runSessionWithHandles serverIn serverOut serverHandler config caps rootDir session = do
196   absRootDir <- canonicalizePath rootDir
197
198   hSetBuffering serverIn  NoBuffering
199   hSetBuffering serverOut NoBuffering
200
201   reqMap <- newMVar newRequestMap
202   messageChan <- newChan
203   initRsp <- newEmptyMVar
204
205   mainThreadId <- myThreadId
206
207   let context = SessionContext serverIn absRootDir messageChan reqMap initRsp config caps
208       initState = SessionState (IdInt 0) mempty mempty 0 False Nothing
209       launchServerHandler = forkIO $ catch (serverHandler serverOut context)
210                                            (throwTo mainThreadId :: SessionException -> IO ())
211   (result, _) <- bracket launchServerHandler killThread $
212     const $ runSession context initState session
213
214   return result
215
216 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
217 updateStateC = awaitForever $ \msg -> do
218   updateState msg
219   yield msg
220
221 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m) => FromServerMessage -> m ()
222 updateState (NotPublishDiagnostics n) = do
223   let List diags = n ^. params . diagnostics
224       doc = n ^. params . uri
225   modify (\s ->
226     let newDiags = Map.insert doc diags (curDiagnostics s)
227       in s { curDiagnostics = newDiags })
228
229 updateState (ReqApplyWorkspaceEdit r) = do
230
231   allChangeParams <- case r ^. params . edit . documentChanges of
232     Just (List cs) -> do
233       mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
234       return $ map getParams cs
235     Nothing -> case r ^. params . edit . changes of
236       Just cs -> do
237         mapM_ checkIfNeedsOpened (HashMap.keys cs)
238         return $ concatMap (uncurry getChangeParams) (HashMap.toList cs)
239       Nothing -> error "No changes!"
240
241   modifyM $ \s -> do
242     newVFS <- liftIO $ changeFromServerVFS (vfs s) r
243     return $ s { vfs = newVFS }
244
245   let groupedParams = groupBy (\a b -> (a ^. textDocument == b ^. textDocument)) allChangeParams
246       mergedParams = map mergeParams groupedParams
247
248   -- TODO: Don't do this when replaying a session
249   forM_ mergedParams (sendMessage . NotificationMessage "2.0" TextDocumentDidChange)
250
251   -- Update VFS to new document versions
252   let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
253       latestVersions = map ((^. textDocument) . last) sortedVersions
254       bumpedVersions = map (version . _Just +~ 1) latestVersions
255
256   forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
257     modify $ \s ->
258       let oldVFS = vfs s
259           update (VirtualFile oldV t) = VirtualFile (fromMaybe oldV v) t
260           newVFS = Map.adjust update uri oldVFS
261       in s { vfs = newVFS }
262
263   where checkIfNeedsOpened uri = do
264           oldVFS <- vfs <$> get
265           ctx <- ask
266
267           -- if its not open, open it
268           unless (uri `Map.member` oldVFS) $ do
269             let fp = fromJust $ uriToFilePath uri
270             contents <- liftIO $ T.readFile fp
271             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
272                 msg = NotificationMessage "2.0" TextDocumentDidOpen (DidOpenTextDocumentParams item)
273             liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
274
275             modifyM $ \s -> do
276               newVFS <- liftIO $ openVFS (vfs s) msg
277               return $ s { vfs = newVFS }
278
279         getParams (TextDocumentEdit docId (List edits)) =
280           let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
281             in DidChangeTextDocumentParams docId (List changeEvents)
282
283         textDocumentVersions uri = map (VersionedTextDocumentIdentifier uri . Just) [0..]
284
285         textDocumentEdits uri edits = map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip (textDocumentVersions uri) edits
286
287         getChangeParams uri (List edits) = map getParams (textDocumentEdits uri (reverse edits))
288
289         mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
290         mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
291                               in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
292 updateState _ = return ()
293
294 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
295 sendMessage msg = do
296   h <- serverIn <$> ask
297   logMsg LogClient msg
298   liftIO $ B.hPut h (addHeader $ encode msg)
299
300 -- | Execute a block f that will throw a 'Timeout' exception
301 -- after duration seconds. This will override the global timeout
302 -- for waiting for messages to arrive defined in 'SessionConfig'.
303 withTimeout :: Int -> Session a -> Session a
304 withTimeout duration f = do
305   chan <- asks messageChan
306   timeoutId <- curTimeoutId <$> get
307   modify $ \s -> s { overridingTimeout = True }
308   liftIO $ forkIO $ do
309     threadDelay (duration * 1000000)
310     writeChan chan (TimeoutMessage timeoutId)
311   res <- f
312   modify $ \s -> s { curTimeoutId = timeoutId + 1,
313                      overridingTimeout = False
314                    }
315   return res
316
317 data LogMsgType = LogServer | LogClient
318   deriving Eq
319
320 -- | Logs the message if the config specified it
321 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
322        => LogMsgType -> a -> m ()
323 logMsg t msg = do
324   shouldLog <- asks $ logMessages . config
325   shouldColor <- asks $ logColor . config
326   liftIO $ when shouldLog $ do
327     when shouldColor $ setSGR [SetColor Foreground Dull color]
328     putStrLn $ arrow ++ showPretty msg
329     when shouldColor $ setSGR [Reset]
330
331   where arrow
332           | t == LogServer  = "<-- "
333           | otherwise       = "--> "
334         color
335           | t == LogServer  = Magenta
336           | otherwise       = Cyan
337
338         showPretty = B.unpack . encodePretty
339