5948ed0a818c1aed1b28cbef5b5f463d1b952ce7
[lsp-test.git] / src / Language / Haskell / LSP / Test / Session.hs
1 {-# LANGUAGE OverloadedStrings #-}
2 {-# LANGUAGE FlexibleInstances #-}
3 {-# LANGUAGE MultiParamTypeClasses #-}
4 {-# LANGUAGE FlexibleContexts #-}
5
6 module Language.Haskell.LSP.Test.Session
7   ( Session
8   , SessionConfig(..)
9   , defaultConfig
10   , SessionMessage(..)
11   , SessionContext(..)
12   , SessionState(..)
13   , runSessionWithHandles
14   , get
15   , put
16   , modify
17   , modifyM
18   , ask
19   , asks
20   , sendMessage
21   , updateState
22   , withTimeout
23   , logMsg
24   , LogMsgType(..)
25   )
26
27 where
28
29 import Control.Concurrent hiding (yield)
30 import Control.Exception
31 import Control.Lens hiding (List)
32 import Control.Monad
33 import Control.Monad.IO.Class
34 import Control.Monad.Except
35 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
36 import qualified Control.Monad.Trans.Reader as Reader (ask)
37 import Control.Monad.Trans.State (StateT, runStateT)
38 import qualified Control.Monad.Trans.State as State (get, put)
39 import qualified Data.ByteString.Lazy.Char8 as B
40 import Data.Aeson
41 import Data.Aeson.Encode.Pretty
42 import Data.Conduit as Conduit
43 import Data.Conduit.Parser as Parser
44 import Data.Default
45 import Data.Foldable
46 import Data.List
47 import qualified Data.Map as Map
48 import qualified Data.Text as T
49 import qualified Data.Text.IO as T
50 import qualified Data.HashMap.Strict as HashMap
51 import Data.Maybe
52 import Data.Function
53 import Language.Haskell.LSP.Messages
54 import Language.Haskell.LSP.Types.Capabilities
55 import Language.Haskell.LSP.Types
56 import Language.Haskell.LSP.Types.Lens hiding (error)
57 import Language.Haskell.LSP.VFS
58 import Language.Haskell.LSP.Test.Decoding
59 import Language.Haskell.LSP.Test.Exceptions
60 import System.Console.ANSI
61 import System.Directory
62 import System.IO
63
64 -- | A session representing one instance of launching and connecting to a server.
65 -- 
66 -- You can send and receive messages to the server within 'Session' via 'getMessage',
67 -- 'sendRequest' and 'sendNotification'.
68 --
69
70 type Session = ParserStateReader FromServerMessage SessionState SessionContext IO
71
72 -- | Stuff you can configure for a 'Session'.
73 data SessionConfig = SessionConfig
74   { messageTimeout :: Int  -- ^ Maximum time to wait for a message in seconds, defaults to 60.
75   , logStdErr      :: Bool -- ^ Redirect the server's stderr to this stdout, defaults to False.
76   , logMessages    :: Bool -- ^ Trace the messages sent and received to stdout, defaults to True.
77   , logColor       :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
78   }
79
80 -- | The configuration used in 'Language.Haskell.LSP.Test.runSession'.
81 defaultConfig :: SessionConfig
82 defaultConfig = SessionConfig 60 False False True
83
84 instance Default SessionConfig where
85   def = defaultConfig
86
87 data SessionMessage = ServerMessage FromServerMessage
88                     | TimeoutMessage Int
89   deriving Show
90
91 data SessionContext = SessionContext
92   {
93     serverIn :: Handle
94   , rootDir :: FilePath
95   , messageChan :: Chan SessionMessage
96   , requestMap :: MVar RequestMap
97   , initRsp :: MVar InitializeResponse
98   , config :: SessionConfig
99   , sessionCapabilities :: ClientCapabilities
100   }
101
102 class Monad m => HasReader r m where
103   ask :: m r
104   asks :: (r -> b) -> m b
105   asks f = f <$> ask
106
107 instance Monad m => HasReader r (ParserStateReader a s r m) where
108   ask = lift $ lift Reader.ask
109
110 instance Monad m => HasReader SessionContext (ConduitM a b (StateT s (ReaderT SessionContext m))) where
111   ask = lift $ lift Reader.ask
112
113 data SessionState = SessionState
114   {
115     curReqId :: LspId
116   , vfs :: VFS
117   , curDiagnostics :: Map.Map Uri [Diagnostic]
118   , curTimeoutId :: Int
119   , overridingTimeout :: Bool
120   -- ^ The last received message from the server.
121   -- Used for providing exception information
122   , lastReceivedMessage :: Maybe FromServerMessage
123   }
124
125 class Monad m => HasState s m where
126   get :: m s
127
128   put :: s -> m ()
129
130   modify :: (s -> s) -> m ()
131   modify f = get >>= put . f
132
133   modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
134   modifyM f = get >>= f >>= put
135
136 instance Monad m => HasState s (ParserStateReader a s r m) where
137   get = lift State.get
138   put = lift . State.put
139
140 instance Monad m => HasState SessionState (ConduitM a b (StateT SessionState m))
141  where
142   get = lift State.get
143   put = lift . State.put
144
145 type ParserStateReader a s r m = ConduitParser a (StateT s (ReaderT r m))
146
147 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
148 runSession context state session = runReaderT (runStateT conduit state) context
149   where
150     conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
151         
152     handler (Unexpected "ConduitParser.empty") = do
153       lastMsg <- fromJust . lastReceivedMessage <$> get
154       name <- getParserName
155       liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
156
157     handler e = throw e
158
159     chanSource = do
160       msg <- liftIO $ readChan (messageChan context)
161       yield msg
162       chanSource
163
164     watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
165     watchdog = Conduit.awaitForever $ \msg -> do
166       curId <- curTimeoutId <$> get
167       case msg of
168         ServerMessage sMsg -> yield sMsg
169         TimeoutMessage tId -> when (curId == tId) $ throw Timeout
170
171 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
172 -- It also does not automatically send initialize and exit messages.
173 runSessionWithHandles :: Handle -- ^ Server in
174                       -> Handle -- ^ Server out
175                       -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
176                       -> SessionConfig
177                       -> ClientCapabilities
178                       -> FilePath -- ^ Root directory
179                       -> Session a
180                       -> IO a
181 runSessionWithHandles serverIn serverOut serverHandler config caps rootDir session = do
182   absRootDir <- canonicalizePath rootDir
183
184   hSetBuffering serverIn  NoBuffering
185   hSetBuffering serverOut NoBuffering
186
187   reqMap <- newMVar newRequestMap
188   messageChan <- newChan
189   initRsp <- newEmptyMVar
190
191   let context = SessionContext serverIn absRootDir messageChan reqMap initRsp config caps
192       initState = SessionState (IdInt 0) mempty mempty 0 False Nothing
193
194   threadId <- forkIO $ void $ serverHandler serverOut context
195   (result, _) <- runSession context initState session
196
197   killThread threadId
198
199   return result
200
201 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
202 updateStateC = awaitForever $ \msg -> do
203   updateState msg
204   yield msg
205
206 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m) => FromServerMessage -> m ()
207 updateState (NotPublishDiagnostics n) = do
208   let List diags = n ^. params . diagnostics
209       doc = n ^. params . uri
210   modify (\s ->
211     let newDiags = Map.insert doc diags (curDiagnostics s)
212       in s { curDiagnostics = newDiags })
213
214 updateState (ReqApplyWorkspaceEdit r) = do
215
216   allChangeParams <- case r ^. params . edit . documentChanges of
217     Just (List cs) -> do
218       mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
219       return $ map getParams cs
220     Nothing -> case r ^. params . edit . changes of
221       Just cs -> do
222         mapM_ checkIfNeedsOpened (HashMap.keys cs)
223         return $ concatMap (uncurry getChangeParams) (HashMap.toList cs)
224       Nothing -> error "No changes!"
225
226   modifyM $ \s -> do
227     newVFS <- liftIO $ changeFromServerVFS (vfs s) r
228     return $ s { vfs = newVFS }
229
230   let groupedParams = groupBy (\a b -> (a ^. textDocument == b ^. textDocument)) allChangeParams
231       mergedParams = map mergeParams groupedParams
232
233   -- TODO: Don't do this when replaying a session
234   forM_ mergedParams (sendMessage . NotificationMessage "2.0" TextDocumentDidChange)
235
236   -- Update VFS to new document versions
237   let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
238       latestVersions = map ((^. textDocument) . last) sortedVersions
239       bumpedVersions = map (version . _Just +~ 1) latestVersions
240
241   forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
242     modify $ \s ->
243       let oldVFS = vfs s
244           update (VirtualFile oldV t) = VirtualFile (fromMaybe oldV v) t
245           newVFS = Map.adjust update uri oldVFS
246       in s { vfs = newVFS }
247
248   where checkIfNeedsOpened uri = do
249           oldVFS <- vfs <$> get
250           ctx <- ask
251
252           -- if its not open, open it
253           unless (uri `Map.member` oldVFS) $ do
254             let fp = fromJust $ uriToFilePath uri
255             contents <- liftIO $ T.readFile fp
256             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
257                 msg = NotificationMessage "2.0" TextDocumentDidOpen (DidOpenTextDocumentParams item)
258             liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
259
260             modifyM $ \s -> do 
261               newVFS <- liftIO $ openVFS (vfs s) msg
262               return $ s { vfs = newVFS }
263
264         getParams (TextDocumentEdit docId (List edits)) =
265           let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
266             in DidChangeTextDocumentParams docId (List changeEvents)
267
268         textDocumentVersions uri = map (VersionedTextDocumentIdentifier uri . Just) [0..]
269
270         textDocumentEdits uri edits = map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip (textDocumentVersions uri) edits
271
272         getChangeParams uri (List edits) = map getParams (textDocumentEdits uri (reverse edits))
273
274         mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
275         mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
276                               in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
277 updateState _ = return ()
278
279 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
280 sendMessage msg = do
281   h <- serverIn <$> ask
282   logMsg LogClient msg
283   liftIO $ B.hPut h (addHeader $ encode msg)
284
285 -- | Execute a block f that will throw a 'TimeoutException'
286 -- after duration seconds. This will override the global timeout
287 -- for waiting for messages to arrive defined in 'SessionConfig'.
288 withTimeout :: Int -> Session a -> Session a
289 withTimeout duration f = do
290   chan <- asks messageChan
291   timeoutId <- curTimeoutId <$> get 
292   modify $ \s -> s { overridingTimeout = True }
293   liftIO $ forkIO $ do
294     threadDelay (duration * 1000000)
295     writeChan chan (TimeoutMessage timeoutId)
296   res <- f
297   modify $ \s -> s { curTimeoutId = timeoutId + 1,
298                      overridingTimeout = False 
299                    }
300   return res
301
302 data LogMsgType = LogServer | LogClient
303   deriving Eq
304
305 -- | Logs the message if the config specified it
306 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
307        => LogMsgType -> a -> m ()
308 logMsg t msg = do
309   shouldLog <- asks $ logMessages . config
310   shouldColor <- asks $ logColor . config
311   liftIO $ when shouldLog $ do
312     when shouldColor $ setSGR [SetColor Foreground Dull color]
313     putStrLn $ arrow ++ showPretty msg
314     when shouldColor $ setSGR [Reset]
315
316   where arrow
317           | t == LogServer  = "<-- "
318           | otherwise       = "--> "
319         color
320           | t == LogServer  = Magenta
321           | otherwise       = Cyan
322   
323         showPretty = B.unpack . encodePretty