dc3c6ff870cccbb488b490b896b6d7bdab9f58f1
[lsp-test.git] / src / Language / Haskell / LSP / Test / Session.hs
1 {-# LANGUAGE CPP               #-}
2 {-# LANGUAGE OverloadedStrings #-}
3 {-# LANGUAGE FlexibleInstances #-}
4 {-# LANGUAGE MultiParamTypeClasses #-}
5 {-# LANGUAGE FlexibleContexts #-}
6 {-# LANGUAGE RankNTypes #-}
7
8 module Language.Haskell.LSP.Test.Session
9   ( Session
10   , SessionConfig(..)
11   , defaultConfig
12   , SessionMessage(..)
13   , SessionContext(..)
14   , SessionState(..)
15   , runSessionWithHandles
16   , get
17   , put
18   , modify
19   , modifyM
20   , ask
21   , asks
22   , sendMessage
23   , updateState
24   , withTimeout
25   , logMsg
26   , LogMsgType(..)
27   )
28
29 where
30
31 import Control.Concurrent hiding (yield)
32 import Control.Exception
33 import Control.Lens hiding (List)
34 import Control.Monad
35 import Control.Monad.Fail
36 import Control.Monad.IO.Class
37 import Control.Monad.Except
38 #if __GLASGOW_HASKELL__ >= 806
39 import qualified Control.Monad.Fail as Fail
40 #endif
41 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
42 import qualified Control.Monad.Trans.Reader as Reader (ask)
43 import Control.Monad.Trans.State (StateT, runStateT)
44 import qualified Control.Monad.Trans.State as State (get, put)
45 import qualified Data.ByteString.Lazy.Char8 as B
46 import Data.Aeson
47 import Data.Aeson.Encode.Pretty
48 import Data.Conduit as Conduit
49 import Data.Conduit.Parser as Parser
50 import Data.Default
51 import Data.Foldable
52 import Data.List
53 import qualified Data.Map as Map
54 import qualified Data.Text as T
55 import qualified Data.Text.IO as T
56 import qualified Data.HashMap.Strict as HashMap
57 import Data.Maybe
58 import Data.Function
59 import Language.Haskell.LSP.Messages
60 import Language.Haskell.LSP.Types.Capabilities
61 import Language.Haskell.LSP.Types
62 import Language.Haskell.LSP.Types.Lens hiding (error)
63 import Language.Haskell.LSP.VFS
64 import Language.Haskell.LSP.Test.Decoding
65 import Language.Haskell.LSP.Test.Exceptions
66 import System.Console.ANSI
67 import System.Directory
68 import System.IO
69
70 -- | A session representing one instance of launching and connecting to a server.
71 -- 
72 -- You can send and receive messages to the server within 'Session' via 'getMessage',
73 -- 'sendRequest' and 'sendNotification'.
74 --
75
76 type Session = ParserStateReader FromServerMessage SessionState SessionContext IO
77
78 #if __GLASGOW_HASKELL__ >= 806
79 instance MonadFail Session where
80   fail s = do
81     lastMsg <- fromJust . lastReceivedMessage <$> get
82     liftIO $ throw (UnexpectedMessage s lastMsg)
83 #endif
84
85 -- | Stuff you can configure for a 'Session'.
86 data SessionConfig = SessionConfig
87   { messageTimeout :: Int  -- ^ Maximum time to wait for a message in seconds, defaults to 60.
88   , logStdErr      :: Bool -- ^ Redirect the server's stderr to this stdout, defaults to False.
89   , logMessages    :: Bool -- ^ Trace the messages sent and received to stdout, defaults to False.
90   , logColor       :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
91   , lspConfig      :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
92   }
93
94 -- | The configuration used in 'Language.Haskell.LSP.Test.runSession'.
95 defaultConfig :: SessionConfig
96 defaultConfig = SessionConfig 60 False True True Nothing
97
98 instance Default SessionConfig where
99   def = defaultConfig
100
101 data SessionMessage = ServerMessage FromServerMessage
102                     | TimeoutMessage Int
103   deriving Show
104
105 data SessionContext = SessionContext
106   {
107     serverIn :: Handle
108   , rootDir :: FilePath
109   , messageChan :: Chan SessionMessage
110   , requestMap :: MVar RequestMap
111   , initRsp :: MVar InitializeResponse
112   , config :: SessionConfig
113   , sessionCapabilities :: ClientCapabilities
114   }
115
116 class Monad m => HasReader r m where
117   ask :: m r
118   asks :: (r -> b) -> m b
119   asks f = f <$> ask
120
121 instance Monad m => HasReader r (ParserStateReader a s r m) where
122   ask = lift $ lift Reader.ask
123
124 instance Monad m => HasReader SessionContext (ConduitM a b (StateT s (ReaderT SessionContext m))) where
125   ask = lift $ lift Reader.ask
126
127 data SessionState = SessionState
128   {
129     curReqId :: LspId
130   , vfs :: VFS
131   , curDiagnostics :: Map.Map Uri [Diagnostic]
132   , curTimeoutId :: Int
133   , overridingTimeout :: Bool
134   -- ^ The last received message from the server.
135   -- Used for providing exception information
136   , lastReceivedMessage :: Maybe FromServerMessage
137   }
138
139 class Monad m => HasState s m where
140   get :: m s
141
142   put :: s -> m ()
143
144   modify :: (s -> s) -> m ()
145   modify f = get >>= put . f
146
147   modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
148   modifyM f = get >>= f >>= put
149
150 instance Monad m => HasState s (ParserStateReader a s r m) where
151   get = lift State.get
152   put = lift . State.put
153
154 instance Monad m => HasState SessionState (ConduitM a b (StateT SessionState m))
155  where
156   get = lift State.get
157   put = lift . State.put
158
159 type ParserStateReader a s r m = ConduitParser a (StateT s (ReaderT r m))
160
161 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
162 runSession context state session = runReaderT (runStateT conduit state) context
163   where
164     conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
165
166     handler (Unexpected "ConduitParser.empty") = do
167       lastMsg <- fromJust . lastReceivedMessage <$> get
168       name <- getParserName
169       liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
170
171     handler e = throw e
172
173     chanSource = do
174       msg <- liftIO $ readChan (messageChan context)
175       yield msg
176       chanSource
177
178     watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
179     watchdog = Conduit.awaitForever $ \msg -> do
180       curId <- curTimeoutId <$> get
181       case msg of
182         ServerMessage sMsg -> yield sMsg
183         TimeoutMessage tId -> when (curId == tId) $ throw Timeout
184
185 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
186 -- It also does not automatically send initialize and exit messages.
187 runSessionWithHandles :: Handle -- ^ Server in
188                       -> Handle -- ^ Server out
189                       -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
190                       -> SessionConfig
191                       -> ClientCapabilities
192                       -> FilePath -- ^ Root directory
193                       -> Session a
194                       -> IO a
195 runSessionWithHandles serverIn serverOut serverHandler config caps rootDir session = do
196   absRootDir <- canonicalizePath rootDir
197
198   hSetBuffering serverIn  NoBuffering
199   hSetBuffering serverOut NoBuffering
200
201   reqMap <- newMVar newRequestMap
202   messageChan <- newChan
203   initRsp <- newEmptyMVar
204
205   let context = SessionContext serverIn absRootDir messageChan reqMap initRsp config caps
206       initState = SessionState (IdInt 0) mempty mempty 0 False Nothing
207
208   threadId <- forkIO $ void $ serverHandler serverOut context
209   (result, _) <- runSession context initState session
210
211   killThread threadId
212
213   return result
214
215 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
216 updateStateC = awaitForever $ \msg -> do
217   updateState msg
218   yield msg
219
220 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m) => FromServerMessage -> m ()
221 updateState (NotPublishDiagnostics n) = do
222   let List diags = n ^. params . diagnostics
223       doc = n ^. params . uri
224   modify (\s ->
225     let newDiags = Map.insert doc diags (curDiagnostics s)
226       in s { curDiagnostics = newDiags })
227
228 updateState (ReqApplyWorkspaceEdit r) = do
229
230   allChangeParams <- case r ^. params . edit . documentChanges of
231     Just (List cs) -> do
232       mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
233       return $ map getParams cs
234     Nothing -> case r ^. params . edit . changes of
235       Just cs -> do
236         mapM_ checkIfNeedsOpened (HashMap.keys cs)
237         return $ concatMap (uncurry getChangeParams) (HashMap.toList cs)
238       Nothing -> error "No changes!"
239
240   modifyM $ \s -> do
241     newVFS <- liftIO $ changeFromServerVFS (vfs s) r
242     return $ s { vfs = newVFS }
243
244   let groupedParams = groupBy (\a b -> (a ^. textDocument == b ^. textDocument)) allChangeParams
245       mergedParams = map mergeParams groupedParams
246
247   -- TODO: Don't do this when replaying a session
248   forM_ mergedParams (sendMessage . NotificationMessage "2.0" TextDocumentDidChange)
249
250   -- Update VFS to new document versions
251   let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
252       latestVersions = map ((^. textDocument) . last) sortedVersions
253       bumpedVersions = map (version . _Just +~ 1) latestVersions
254
255   forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
256     modify $ \s ->
257       let oldVFS = vfs s
258           update (VirtualFile oldV t) = VirtualFile (fromMaybe oldV v) t
259           newVFS = Map.adjust update uri oldVFS
260       in s { vfs = newVFS }
261
262   where checkIfNeedsOpened uri = do
263           oldVFS <- vfs <$> get
264           ctx <- ask
265
266           -- if its not open, open it
267           unless (uri `Map.member` oldVFS) $ do
268             let fp = fromJust $ uriToFilePath uri
269             contents <- liftIO $ T.readFile fp
270             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
271                 msg = NotificationMessage "2.0" TextDocumentDidOpen (DidOpenTextDocumentParams item)
272             liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
273
274             modifyM $ \s -> do
275               newVFS <- liftIO $ openVFS (vfs s) msg
276               return $ s { vfs = newVFS }
277
278         getParams (TextDocumentEdit docId (List edits)) =
279           let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
280             in DidChangeTextDocumentParams docId (List changeEvents)
281
282         textDocumentVersions uri = map (VersionedTextDocumentIdentifier uri . Just) [0..]
283
284         textDocumentEdits uri edits = map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip (textDocumentVersions uri) edits
285
286         getChangeParams uri (List edits) = map getParams (textDocumentEdits uri (reverse edits))
287
288         mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
289         mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
290                               in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
291 updateState _ = return ()
292
293 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
294 sendMessage msg = do
295   h <- serverIn <$> ask
296   logMsg LogClient msg
297   liftIO $ B.hPut h (addHeader $ encode msg)
298
299 -- | Execute a block f that will throw a 'TimeoutException'
300 -- after duration seconds. This will override the global timeout
301 -- for waiting for messages to arrive defined in 'SessionConfig'.
302 withTimeout :: Int -> Session a -> Session a
303 withTimeout duration f = do
304   chan <- asks messageChan
305   timeoutId <- curTimeoutId <$> get
306   modify $ \s -> s { overridingTimeout = True }
307   liftIO $ forkIO $ do
308     threadDelay (duration * 1000000)
309     writeChan chan (TimeoutMessage timeoutId)
310   res <- f
311   modify $ \s -> s { curTimeoutId = timeoutId + 1,
312                      overridingTimeout = False
313                    }
314   return res
315
316 data LogMsgType = LogServer | LogClient
317   deriving Eq
318
319 -- | Logs the message if the config specified it
320 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
321        => LogMsgType -> a -> m ()
322 logMsg t msg = do
323   shouldLog <- asks $ logMessages . config
324   shouldColor <- asks $ logColor . config
325   liftIO $ when shouldLog $ do
326     when shouldColor $ setSGR [SetColor Foreground Dull color]
327     putStrLn $ arrow ++ showPretty msg
328     when shouldColor $ setSGR [Reset]
329
330   where arrow
331           | t == LogServer  = "<-- "
332           | otherwise       = "--> "
333         color
334           | t == LogServer  = Magenta
335           | otherwise       = Cyan
336
337         showPretty = B.unpack . encodePretty