07b33e888c5c3eb44f931ce138e87c2569acdcea
[lsp-test.git] / src / Language / Haskell / LSP / Test / Session.hs
1 {-# LANGUAGE CPP               #-}
2 {-# LANGUAGE OverloadedStrings #-}
3 {-# LANGUAGE GeneralizedNewtypeDeriving #-}
4 {-# LANGUAGE FlexibleInstances #-}
5 {-# LANGUAGE MultiParamTypeClasses #-}
6 {-# LANGUAGE FlexibleContexts #-}
7 {-# LANGUAGE RankNTypes #-}
8
9 module Language.Haskell.LSP.Test.Session
10   ( Session(..)
11   , SessionConfig(..)
12   , defaultConfig
13   , SessionMessage(..)
14   , SessionContext(..)
15   , SessionState(..)
16   , runSessionWithHandles
17   , get
18   , put
19   , modify
20   , modifyM
21   , ask
22   , asks
23   , sendMessage
24   , updateState
25   , withTimeout
26   , logMsg
27   , LogMsgType(..)
28   )
29
30 where
31
32 import Control.Applicative
33 import Control.Concurrent hiding (yield)
34 import Control.Exception
35 import Control.Lens hiding (List)
36 import Control.Monad
37 import Control.Monad.IO.Class
38 import Control.Monad.Except
39 #if __GLASGOW_HASKELL__ == 806
40 import Control.Monad.Fail
41 #endif
42 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
43 import qualified Control.Monad.Trans.Reader as Reader (ask)
44 import Control.Monad.Trans.State (StateT, runStateT)
45 import qualified Control.Monad.Trans.State as State
46 import qualified Data.ByteString.Lazy.Char8 as B
47 import Data.Aeson
48 import Data.Aeson.Encode.Pretty
49 import Data.Conduit as Conduit
50 import Data.Conduit.Parser as Parser
51 import Data.Default
52 import Data.Foldable
53 import Data.List
54 import qualified Data.Map as Map
55 import qualified Data.Text as T
56 import qualified Data.Text.IO as T
57 import qualified Data.HashMap.Strict as HashMap
58 import Data.Maybe
59 import Data.Function
60 import Language.Haskell.LSP.Messages
61 import Language.Haskell.LSP.Types.Capabilities
62 import Language.Haskell.LSP.Types
63 import Language.Haskell.LSP.Types.Lens hiding (error)
64 import Language.Haskell.LSP.VFS
65 import Language.Haskell.LSP.Test.Compat
66 import Language.Haskell.LSP.Test.Decoding
67 import Language.Haskell.LSP.Test.Exceptions
68 import System.Console.ANSI
69 import System.Directory
70 import System.IO
71 import System.Process (ProcessHandle())
72 import System.Timeout
73 import System.IO.Temp
74
75 -- | A session representing one instance of launching and connecting to a server.
76 --
77 -- You can send and receive messages to the server within 'Session' via
78 -- 'Language.Haskell.LSP.Test.message',
79 -- 'Language.Haskell.LSP.Test.sendRequest' and
80 -- 'Language.Haskell.LSP.Test.sendNotification'.
81
82 newtype Session a = Session (ConduitParser FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) a)
83   deriving (Functor, Applicative, Monad, MonadIO, Alternative)
84
85 #if __GLASGOW_HASKELL__ >= 806
86 instance MonadFail Session where
87   fail s = do
88     lastMsg <- fromJust . lastReceivedMessage <$> get
89     liftIO $ throw (UnexpectedMessage s lastMsg)
90 #endif
91
92 -- | Stuff you can configure for a 'Session'.
93 data SessionConfig = SessionConfig
94   { messageTimeout :: Int  -- ^ Maximum time to wait for a message in seconds, defaults to 60.
95   , logStdErr      :: Bool -- ^ Redirect the server's stderr to this stdout, defaults to False.
96   , logMessages    :: Bool -- ^ Trace the messages sent and received to stdout, defaults to False.
97   , logColor       :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
98   , lspConfig      :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
99   }
100
101 -- | The configuration used in 'Language.Haskell.LSP.Test.runSession'.
102 defaultConfig :: SessionConfig
103 defaultConfig = SessionConfig 60 False False True Nothing
104
105 instance Default SessionConfig where
106   def = defaultConfig
107
108 data SessionMessage = ServerMessage FromServerMessage
109                     | TimeoutMessage Int
110   deriving Show
111
112 data SessionContext = SessionContext
113   {
114     serverIn :: Handle
115   , rootDir :: FilePath
116   , messageChan :: Chan SessionMessage
117   , requestMap :: MVar RequestMap
118   , initRsp :: MVar InitializeResponse
119   , config :: SessionConfig
120   , sessionCapabilities :: ClientCapabilities
121   }
122
123 class Monad m => HasReader r m where
124   ask :: m r
125   asks :: (r -> b) -> m b
126   asks f = f <$> ask
127
128 instance HasReader SessionContext Session where
129   ask  = Session (lift $ lift Reader.ask)
130
131 instance Monad m => HasReader r (ConduitM a b (StateT s (ReaderT r m))) where
132   ask = lift $ lift Reader.ask
133
134 data SessionState = SessionState
135   {
136     curReqId :: LspId
137   , vfs :: VFS
138   , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
139   , curTimeoutId :: Int
140   , overridingTimeout :: Bool
141   -- ^ The last received message from the server.
142   -- Used for providing exception information
143   , lastReceivedMessage :: Maybe FromServerMessage
144   }
145
146 class Monad m => HasState s m where
147   get :: m s
148
149   put :: s -> m ()
150
151   modify :: (s -> s) -> m ()
152   modify f = get >>= put . f
153
154   modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
155   modifyM f = get >>= f >>= put
156
157 instance HasState SessionState Session where
158   get = Session (lift State.get)
159   put = Session . lift . State.put
160
161 instance Monad m => HasState s (ConduitM a b (StateT s m))
162  where
163   get = lift State.get
164   put = lift . State.put
165
166 instance Monad m => HasState s (ConduitParser a (StateT s m))
167  where
168   get = lift State.get
169   put = lift . State.put
170
171 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
172 runSession context state (Session session) = runReaderT (runStateT conduit state) context
173   where
174     conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
175
176     handler (Unexpected "ConduitParser.empty") = do
177       lastMsg <- fromJust . lastReceivedMessage <$> get
178       name <- getParserName
179       liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
180
181     handler e = throw e
182
183     chanSource = do
184       msg <- liftIO $ readChan (messageChan context)
185       yield msg
186       chanSource
187
188     watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
189     watchdog = Conduit.awaitForever $ \msg -> do
190       curId <- curTimeoutId <$> get
191       case msg of
192         ServerMessage sMsg -> yield sMsg
193         TimeoutMessage tId -> when (curId == tId) $ throw Timeout
194
195 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
196 -- It also does not automatically send initialize and exit messages.
197 runSessionWithHandles :: Handle -- ^ Server in
198                       -> Handle -- ^ Server out
199                       -> ProcessHandle -- ^ Server process
200                       -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
201                       -> SessionConfig
202                       -> ClientCapabilities
203                       -> FilePath -- ^ Root directory
204                       -> Session () -- ^ To exit the Server properly
205                       -> Session a
206                       -> IO a
207 runSessionWithHandles serverIn serverOut serverProc serverHandler config caps rootDir exitServer session = do
208   absRootDir <- canonicalizePath rootDir
209
210   hSetBuffering serverIn  NoBuffering
211   hSetBuffering serverOut NoBuffering
212   -- This is required to make sure that we don’t get any
213   -- newline conversion or weird encoding issues.
214   hSetBinaryMode serverIn True
215   hSetBinaryMode serverOut True
216
217   reqMap <- newMVar newRequestMap
218   messageChan <- newChan
219   initRsp <- newEmptyMVar
220
221   mainThreadId <- myThreadId
222
223   let context = SessionContext serverIn absRootDir messageChan reqMap initRsp config caps
224       initState vfs = SessionState (IdInt 0) vfs
225                                        mempty 0 False Nothing
226       runSession' ses = initVFS $ \vfs -> runSession context (initState vfs) ses
227
228       errorHandler = throwTo mainThreadId :: SessionException -> IO()
229       serverListenerLauncher =
230         forkIO $ catch (serverHandler serverOut context) errorHandler
231       server = (Just serverIn, Just serverOut, Nothing, serverProc)
232       serverAndListenerFinalizer tid =
233         finally (timeout (messageTimeout config * 1000000)
234                          (runSession' exitServer))
235                 (cleanupProcess server >> killThread tid)
236
237   (result, _) <- bracket serverListenerLauncher serverAndListenerFinalizer
238                          (const $ runSession' session)
239   return result
240
241 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
242 updateStateC = awaitForever $ \msg -> do
243   updateState msg
244   yield msg
245
246 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m)
247             => FromServerMessage -> m ()
248 updateState (NotPublishDiagnostics n) = do
249   let List diags = n ^. params . diagnostics
250       doc = n ^. params . uri
251   modify (\s ->
252     let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s)
253       in s { curDiagnostics = newDiags })
254
255 updateState (ReqApplyWorkspaceEdit r) = do
256
257   allChangeParams <- case r ^. params . edit . documentChanges of
258     Just (List cs) -> do
259       mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
260       return $ map getParams cs
261     Nothing -> case r ^. params . edit . changes of
262       Just cs -> do
263         mapM_ checkIfNeedsOpened (HashMap.keys cs)
264         return $ concatMap (uncurry getChangeParams) (HashMap.toList cs)
265       Nothing -> error "No changes!"
266
267   modifyM $ \s -> do
268     newVFS <- liftIO $ changeFromServerVFS (vfs s) r
269     return $ s { vfs = newVFS }
270
271   let groupedParams = groupBy (\a b -> a ^. textDocument == b ^. textDocument) allChangeParams
272       mergedParams = map mergeParams groupedParams
273
274   -- TODO: Don't do this when replaying a session
275   forM_ mergedParams (sendMessage . NotificationMessage "2.0" TextDocumentDidChange)
276
277   -- Update VFS to new document versions
278   let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
279       latestVersions = map ((^. textDocument) . last) sortedVersions
280       bumpedVersions = map (version . _Just +~ 1) latestVersions
281
282   forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
283     modify $ \s ->
284       let oldVFS = vfs s
285           update (VirtualFile oldV t) = VirtualFile (fromMaybe oldV v) t
286           newVFS = updateVFS (Map.adjust update (toNormalizedUri uri)) oldVFS
287       in s { vfs = newVFS }
288
289   where checkIfNeedsOpened uri = do
290           oldVFS <- vfs <$> get
291           ctx <- ask
292
293           -- if its not open, open it
294           unless (toNormalizedUri uri `Map.member` (vfsMap oldVFS)) $ do
295             let fp = fromJust $ uriToFilePath uri
296             contents <- liftIO $ T.readFile fp
297             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
298                 msg = NotificationMessage "2.0" TextDocumentDidOpen (DidOpenTextDocumentParams item)
299             liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
300
301             modifyM $ \s -> do
302               let (newVFS,_) = openVFS (vfs s) msg
303               return $ s { vfs = newVFS }
304
305         getParams (TextDocumentEdit docId (List edits)) =
306           let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
307             in DidChangeTextDocumentParams docId (List changeEvents)
308
309         textDocumentVersions uri = map (VersionedTextDocumentIdentifier uri . Just) [0..]
310
311         textDocumentEdits uri edits = map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip (textDocumentVersions uri) edits
312
313         getChangeParams uri (List edits) = map getParams (textDocumentEdits uri (reverse edits))
314
315         mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
316         mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
317                               in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
318 updateState _ = return ()
319
320 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
321 sendMessage msg = do
322   h <- serverIn <$> ask
323   logMsg LogClient msg
324   liftIO $ B.hPut h (addHeader $ encode msg)
325
326 -- | Execute a block f that will throw a 'Timeout' exception
327 -- after duration seconds. This will override the global timeout
328 -- for waiting for messages to arrive defined in 'SessionConfig'.
329 withTimeout :: Int -> Session a -> Session a
330 withTimeout duration f = do
331   chan <- asks messageChan
332   timeoutId <- curTimeoutId <$> get
333   modify $ \s -> s { overridingTimeout = True }
334   liftIO $ forkIO $ do
335     threadDelay (duration * 1000000)
336     writeChan chan (TimeoutMessage timeoutId)
337   res <- f
338   modify $ \s -> s { curTimeoutId = timeoutId + 1,
339                      overridingTimeout = False
340                    }
341   return res
342
343 data LogMsgType = LogServer | LogClient
344   deriving Eq
345
346 -- | Logs the message if the config specified it
347 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
348        => LogMsgType -> a -> m ()
349 logMsg t msg = do
350   shouldLog <- asks $ logMessages . config
351   shouldColor <- asks $ logColor . config
352   liftIO $ when shouldLog $ do
353     when shouldColor $ setSGR [SetColor Foreground Dull color]
354     putStrLn $ arrow ++ showPretty msg
355     when shouldColor $ setSGR [Reset]
356
357   where arrow
358           | t == LogServer  = "<-- "
359           | otherwise       = "--> "
360         color
361           | t == LogServer  = Magenta
362           | otherwise       = Cyan
363
364         showPretty = B.unpack . encodePretty
365