2df5ab3c4263e725cc240b22933480943071ac73
[lsp-test.git] / src / Language / Haskell / LSP / Test / Session.hs
1 {-# LANGUAGE CPP               #-}
2 {-# LANGUAGE OverloadedStrings #-}
3 {-# LANGUAGE FlexibleInstances #-}
4 {-# LANGUAGE MultiParamTypeClasses #-}
5 {-# LANGUAGE FlexibleContexts #-}
6 {-# LANGUAGE RankNTypes #-}
7
8 module Language.Haskell.LSP.Test.Session
9   ( Session
10   , SessionConfig(..)
11   , defaultConfig
12   , SessionMessage(..)
13   , SessionContext(..)
14   , SessionState(..)
15   , runSessionWithHandles
16   , get
17   , put
18   , modify
19   , modifyM
20   , ask
21   , asks
22   , sendMessage
23   , updateState
24   , withTimeout
25   , logMsg
26   , LogMsgType(..)
27   )
28
29 where
30
31 import Control.Concurrent hiding (yield)
32 import Control.Exception
33 import Control.Lens hiding (List)
34 import Control.Monad
35 import Control.Monad.IO.Class
36 import Control.Monad.Except
37 #if __GLASGOW_HASKELL__ >= 806
38 import Control.Monad.Fail
39 #endif
40 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
41 import qualified Control.Monad.Trans.Reader as Reader (ask)
42 import Control.Monad.Trans.State (StateT, runStateT)
43 import qualified Control.Monad.Trans.State as State (get, put)
44 import qualified Data.ByteString.Lazy.Char8 as B
45 import Data.Aeson
46 import Data.Aeson.Encode.Pretty
47 import Data.Conduit as Conduit
48 import Data.Conduit.Parser as Parser
49 import Data.Default
50 import Data.Foldable
51 import Data.List
52 import qualified Data.Map as Map
53 import qualified Data.Text as T
54 import qualified Data.Text.IO as T
55 import qualified Data.HashMap.Strict as HashMap
56 import Data.Maybe
57 import Data.Function
58 import Language.Haskell.LSP.Messages
59 import Language.Haskell.LSP.Types.Capabilities
60 import Language.Haskell.LSP.Types
61 import Language.Haskell.LSP.Types.Lens hiding (error)
62 import Language.Haskell.LSP.VFS
63 import Language.Haskell.LSP.Test.Decoding
64 import Language.Haskell.LSP.Test.Exceptions
65 import System.Console.ANSI
66 import System.Directory
67 import System.IO
68
69 -- | A session representing one instance of launching and connecting to a server.
70 -- 
71 -- You can send and receive messages to the server within 'Session' via 'getMessage',
72 -- 'sendRequest' and 'sendNotification'.
73 --
74
75 type Session = ParserStateReader FromServerMessage SessionState SessionContext IO
76
77 #if __GLASGOW_HASKELL__ >= 806
78 instance MonadFail Session where
79   fail s = do
80     lastMsg <- fromJust . lastReceivedMessage <$> get
81     liftIO $ throw (UnexpectedMessage s lastMsg)
82 #endif
83
84 -- | Stuff you can configure for a 'Session'.
85 data SessionConfig = SessionConfig
86   { messageTimeout :: Int  -- ^ Maximum time to wait for a message in seconds, defaults to 60.
87   , logStdErr      :: Bool -- ^ Redirect the server's stderr to this stdout, defaults to False.
88   , logMessages    :: Bool -- ^ Trace the messages sent and received to stdout, defaults to False.
89   , logColor       :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
90   , lspConfig      :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
91   }
92
93 -- | The configuration used in 'Language.Haskell.LSP.Test.runSession'.
94 defaultConfig :: SessionConfig
95 defaultConfig = SessionConfig 60 False True True Nothing
96
97 instance Default SessionConfig where
98   def = defaultConfig
99
100 data SessionMessage = ServerMessage FromServerMessage
101                     | TimeoutMessage Int
102   deriving Show
103
104 data SessionContext = SessionContext
105   {
106     serverIn :: Handle
107   , rootDir :: FilePath
108   , messageChan :: Chan SessionMessage
109   , requestMap :: MVar RequestMap
110   , initRsp :: MVar InitializeResponse
111   , config :: SessionConfig
112   , sessionCapabilities :: ClientCapabilities
113   }
114
115 class Monad m => HasReader r m where
116   ask :: m r
117   asks :: (r -> b) -> m b
118   asks f = f <$> ask
119
120 instance Monad m => HasReader r (ParserStateReader a s r m) where
121   ask = lift $ lift Reader.ask
122
123 instance Monad m => HasReader SessionContext (ConduitM a b (StateT s (ReaderT SessionContext m))) where
124   ask = lift $ lift Reader.ask
125
126 data SessionState = SessionState
127   {
128     curReqId :: LspId
129   , vfs :: VFS
130   , curDiagnostics :: Map.Map Uri [Diagnostic]
131   , curTimeoutId :: Int
132   , overridingTimeout :: Bool
133   -- ^ The last received message from the server.
134   -- Used for providing exception information
135   , lastReceivedMessage :: Maybe FromServerMessage
136   }
137
138 class Monad m => HasState s m where
139   get :: m s
140
141   put :: s -> m ()
142
143   modify :: (s -> s) -> m ()
144   modify f = get >>= put . f
145
146   modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
147   modifyM f = get >>= f >>= put
148
149 instance Monad m => HasState s (ParserStateReader a s r m) where
150   get = lift State.get
151   put = lift . State.put
152
153 instance Monad m => HasState SessionState (ConduitM a b (StateT SessionState m))
154  where
155   get = lift State.get
156   put = lift . State.put
157
158 type ParserStateReader a s r m = ConduitParser a (StateT s (ReaderT r m))
159
160 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
161 runSession context state session = runReaderT (runStateT conduit state) context
162   where
163     conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
164
165     handler (Unexpected "ConduitParser.empty") = do
166       lastMsg <- fromJust . lastReceivedMessage <$> get
167       name <- getParserName
168       liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
169
170     handler e = throw e
171
172     chanSource = do
173       msg <- liftIO $ readChan (messageChan context)
174       yield msg
175       chanSource
176
177     watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
178     watchdog = Conduit.awaitForever $ \msg -> do
179       curId <- curTimeoutId <$> get
180       case msg of
181         ServerMessage sMsg -> yield sMsg
182         TimeoutMessage tId -> when (curId == tId) $ throw Timeout
183
184 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
185 -- It also does not automatically send initialize and exit messages.
186 runSessionWithHandles :: Handle -- ^ Server in
187                       -> Handle -- ^ Server out
188                       -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
189                       -> SessionConfig
190                       -> ClientCapabilities
191                       -> FilePath -- ^ Root directory
192                       -> Session a
193                       -> IO a
194 runSessionWithHandles serverIn serverOut serverHandler config caps rootDir session = do
195   absRootDir <- canonicalizePath rootDir
196
197   hSetBuffering serverIn  NoBuffering
198   hSetBuffering serverOut NoBuffering
199
200   reqMap <- newMVar newRequestMap
201   messageChan <- newChan
202   initRsp <- newEmptyMVar
203
204   let context = SessionContext serverIn absRootDir messageChan reqMap initRsp config caps
205       initState = SessionState (IdInt 0) mempty mempty 0 False Nothing
206
207   threadId <- forkIO $ void $ serverHandler serverOut context
208   (result, _) <- runSession context initState session
209
210   killThread threadId
211
212   return result
213
214 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
215 updateStateC = awaitForever $ \msg -> do
216   updateState msg
217   yield msg
218
219 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m) => FromServerMessage -> m ()
220 updateState (NotPublishDiagnostics n) = do
221   let List diags = n ^. params . diagnostics
222       doc = n ^. params . uri
223   modify (\s ->
224     let newDiags = Map.insert doc diags (curDiagnostics s)
225       in s { curDiagnostics = newDiags })
226
227 updateState (ReqApplyWorkspaceEdit r) = do
228
229   allChangeParams <- case r ^. params . edit . documentChanges of
230     Just (List cs) -> do
231       mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
232       return $ map getParams cs
233     Nothing -> case r ^. params . edit . changes of
234       Just cs -> do
235         mapM_ checkIfNeedsOpened (HashMap.keys cs)
236         return $ concatMap (uncurry getChangeParams) (HashMap.toList cs)
237       Nothing -> error "No changes!"
238
239   modifyM $ \s -> do
240     newVFS <- liftIO $ changeFromServerVFS (vfs s) r
241     return $ s { vfs = newVFS }
242
243   let groupedParams = groupBy (\a b -> (a ^. textDocument == b ^. textDocument)) allChangeParams
244       mergedParams = map mergeParams groupedParams
245
246   -- TODO: Don't do this when replaying a session
247   forM_ mergedParams (sendMessage . NotificationMessage "2.0" TextDocumentDidChange)
248
249   -- Update VFS to new document versions
250   let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
251       latestVersions = map ((^. textDocument) . last) sortedVersions
252       bumpedVersions = map (version . _Just +~ 1) latestVersions
253
254   forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
255     modify $ \s ->
256       let oldVFS = vfs s
257           update (VirtualFile oldV t) = VirtualFile (fromMaybe oldV v) t
258           newVFS = Map.adjust update uri oldVFS
259       in s { vfs = newVFS }
260
261   where checkIfNeedsOpened uri = do
262           oldVFS <- vfs <$> get
263           ctx <- ask
264
265           -- if its not open, open it
266           unless (uri `Map.member` oldVFS) $ do
267             let fp = fromJust $ uriToFilePath uri
268             contents <- liftIO $ T.readFile fp
269             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
270                 msg = NotificationMessage "2.0" TextDocumentDidOpen (DidOpenTextDocumentParams item)
271             liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
272
273             modifyM $ \s -> do
274               newVFS <- liftIO $ openVFS (vfs s) msg
275               return $ s { vfs = newVFS }
276
277         getParams (TextDocumentEdit docId (List edits)) =
278           let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
279             in DidChangeTextDocumentParams docId (List changeEvents)
280
281         textDocumentVersions uri = map (VersionedTextDocumentIdentifier uri . Just) [0..]
282
283         textDocumentEdits uri edits = map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip (textDocumentVersions uri) edits
284
285         getChangeParams uri (List edits) = map getParams (textDocumentEdits uri (reverse edits))
286
287         mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
288         mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
289                               in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
290 updateState _ = return ()
291
292 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
293 sendMessage msg = do
294   h <- serverIn <$> ask
295   logMsg LogClient msg
296   liftIO $ B.hPut h (addHeader $ encode msg)
297
298 -- | Execute a block f that will throw a 'TimeoutException'
299 -- after duration seconds. This will override the global timeout
300 -- for waiting for messages to arrive defined in 'SessionConfig'.
301 withTimeout :: Int -> Session a -> Session a
302 withTimeout duration f = do
303   chan <- asks messageChan
304   timeoutId <- curTimeoutId <$> get
305   modify $ \s -> s { overridingTimeout = True }
306   liftIO $ forkIO $ do
307     threadDelay (duration * 1000000)
308     writeChan chan (TimeoutMessage timeoutId)
309   res <- f
310   modify $ \s -> s { curTimeoutId = timeoutId + 1,
311                      overridingTimeout = False
312                    }
313   return res
314
315 data LogMsgType = LogServer | LogClient
316   deriving Eq
317
318 -- | Logs the message if the config specified it
319 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
320        => LogMsgType -> a -> m ()
321 logMsg t msg = do
322   shouldLog <- asks $ logMessages . config
323   shouldColor <- asks $ logColor . config
324   liftIO $ when shouldLog $ do
325     when shouldColor $ setSGR [SetColor Foreground Dull color]
326     putStrLn $ arrow ++ showPretty msg
327     when shouldColor $ setSGR [Reset]
328
329   where arrow
330           | t == LogServer  = "<-- "
331           | otherwise       = "--> "
332         color
333           | t == LogServer  = Magenta
334           | otherwise       = Cyan
335
336         showPretty = B.unpack . encodePretty