1abad54d673728031afabbbd332fcebc396631f8
[lsp-test.git] / src / Language / Haskell / LSP / Test / Session.hs
1 {-# LANGUAGE CPP               #-}
2 {-# LANGUAGE OverloadedStrings #-}
3 {-# LANGUAGE GeneralizedNewtypeDeriving #-}
4 {-# LANGUAGE FlexibleInstances #-}
5 {-# LANGUAGE MultiParamTypeClasses #-}
6 {-# LANGUAGE FlexibleContexts #-}
7 {-# LANGUAGE RankNTypes #-}
8
9 module Language.Haskell.LSP.Test.Session
10   ( Session(..)
11   , SessionConfig(..)
12   , defaultConfig
13   , SessionMessage(..)
14   , SessionContext(..)
15   , SessionState(..)
16   , runSessionWithHandles
17   , get
18   , put
19   , modify
20   , modifyM
21   , ask
22   , asks
23   , sendMessage
24   , updateState
25   , withTimeout
26   , logMsg
27   , LogMsgType(..)
28   )
29
30 where
31
32 import Control.Applicative
33 import Control.Concurrent hiding (yield)
34 import Control.Exception
35 import Control.Lens hiding (List)
36 import Control.Monad
37 import Control.Monad.IO.Class
38 import Control.Monad.Except
39 #if __GLASGOW_HASKELL__ == 806
40 import Control.Monad.Fail
41 #endif
42 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
43 import qualified Control.Monad.Trans.Reader as Reader (ask)
44 import Control.Monad.Trans.State (StateT, runStateT)
45 import qualified Control.Monad.Trans.State as State
46 import qualified Data.ByteString.Lazy.Char8 as B
47 import Data.Aeson
48 import Data.Aeson.Encode.Pretty
49 import Data.Conduit as Conduit
50 import Data.Conduit.Parser as Parser
51 import Data.Default
52 import Data.Foldable
53 import Data.List
54 import qualified Data.Map as Map
55 import qualified Data.Text as T
56 import qualified Data.Text.IO as T
57 import qualified Data.HashMap.Strict as HashMap
58 import Data.Maybe
59 import Data.Function
60 import Language.Haskell.LSP.Messages
61 import Language.Haskell.LSP.Types.Capabilities
62 import Language.Haskell.LSP.Types
63 import Language.Haskell.LSP.Types.Lens hiding (error)
64 import Language.Haskell.LSP.VFS
65 import Language.Haskell.LSP.Test.Compat
66 import Language.Haskell.LSP.Test.Decoding
67 import Language.Haskell.LSP.Test.Exceptions
68 import System.Console.ANSI
69 import System.Directory
70 import System.IO
71 import System.Process (ProcessHandle())
72 import System.Timeout
73
74 -- | A session representing one instance of launching and connecting to a server.
75 --
76 -- You can send and receive messages to the server within 'Session' via
77 -- 'Language.Haskell.LSP.Test.message',
78 -- 'Language.Haskell.LSP.Test.sendRequest' and
79 -- 'Language.Haskell.LSP.Test.sendNotification'.
80
81 -- newtype Session a = Session (ParserStateReader FromServerMessage SessionState SessionContext IO a)
82
83 newtype Session a = Session (ConduitParser FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) a)
84   deriving (Functor, Applicative, Monad, MonadIO, Alternative)
85
86 #if __GLASGOW_HASKELL__ >= 806
87 instance MonadFail Session where
88   fail s = do
89     lastMsg <- fromJust . lastReceivedMessage <$> get
90     liftIO $ throw (UnexpectedMessage s lastMsg)
91 #endif
92
93 -- | Stuff you can configure for a 'Session'.
94 data SessionConfig = SessionConfig
95   { messageTimeout :: Int  -- ^ Maximum time to wait for a message in seconds, defaults to 60.
96   , logStdErr      :: Bool -- ^ Redirect the server's stderr to this stdout, defaults to False.
97   , logMessages    :: Bool -- ^ Trace the messages sent and received to stdout, defaults to False.
98   , logColor       :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
99   , lspConfig      :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
100   }
101
102 -- | The configuration used in 'Language.Haskell.LSP.Test.runSession'.
103 defaultConfig :: SessionConfig
104 defaultConfig = SessionConfig 60 False False True Nothing
105
106 instance Default SessionConfig where
107   def = defaultConfig
108
109 data SessionMessage = ServerMessage FromServerMessage
110                     | TimeoutMessage Int
111   deriving Show
112
113 data SessionContext = SessionContext
114   {
115     serverIn :: Handle
116   , rootDir :: FilePath
117   , messageChan :: Chan SessionMessage
118   , requestMap :: MVar RequestMap
119   , initRsp :: MVar InitializeResponse
120   , config :: SessionConfig
121   , sessionCapabilities :: ClientCapabilities
122   }
123
124 class Monad m => HasReader r m where
125   ask :: m r
126   asks :: (r -> b) -> m b
127   asks f = f <$> ask
128
129 instance HasReader SessionContext Session where
130   ask  = Session (lift $ lift Reader.ask)
131
132 instance Monad m => HasReader r (ConduitM a b (StateT s (ReaderT r m))) where
133   ask = lift $ lift Reader.ask
134
135 data SessionState = SessionState
136   {
137     curReqId :: LspId
138   , vfs :: VFS
139   , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
140   , curTimeoutId :: Int
141   , overridingTimeout :: Bool
142   -- ^ The last received message from the server.
143   -- Used for providing exception information
144   , lastReceivedMessage :: Maybe FromServerMessage
145   }
146
147 class Monad m => HasState s m where
148   get :: m s
149
150   put :: s -> m ()
151
152   modify :: (s -> s) -> m ()
153   modify f = get >>= put . f
154
155   modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
156   modifyM f = get >>= f >>= put
157
158 instance HasState SessionState Session where
159   get = Session (lift State.get)
160   put = Session . lift . State.put
161
162 instance Monad m => HasState s (ConduitM a b (StateT s m))
163  where
164   get = lift State.get
165   put = lift . State.put
166
167 instance Monad m => HasState s (ConduitParser a (StateT s m))
168  where
169   get = lift State.get
170   put = lift . State.put
171
172 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
173 runSession context state (Session session) = runReaderT (runStateT conduit state) context
174   where
175     conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
176
177     handler (Unexpected "ConduitParser.empty") = do
178       lastMsg <- fromJust . lastReceivedMessage <$> get
179       name <- getParserName
180       liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
181
182     handler e = throw e
183
184     chanSource = do
185       msg <- liftIO $ readChan (messageChan context)
186       yield msg
187       chanSource
188
189     watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
190     watchdog = Conduit.awaitForever $ \msg -> do
191       curId <- curTimeoutId <$> get
192       case msg of
193         ServerMessage sMsg -> yield sMsg
194         TimeoutMessage tId -> when (curId == tId) $ throw Timeout
195
196 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
197 -- It also does not automatically send initialize and exit messages.
198 runSessionWithHandles :: Handle -- ^ Server in
199                       -> Handle -- ^ Server out
200                       -> ProcessHandle -- ^ Server process
201                       -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
202                       -> SessionConfig
203                       -> ClientCapabilities
204                       -> FilePath -- ^ Root directory
205                       -> Session () -- ^ To exit the Server properly
206                       -> Session a
207                       -> IO a
208 runSessionWithHandles serverIn serverOut serverProc serverHandler config caps rootDir exitServer session = do
209   absRootDir <- canonicalizePath rootDir
210
211   hSetBuffering serverIn  NoBuffering
212   hSetBuffering serverOut NoBuffering
213   -- This is required to make sure that we don’t get any
214   -- newline conversion or weird encoding issues.
215   hSetBinaryMode serverIn True
216   hSetBinaryMode serverOut True
217
218   reqMap <- newMVar newRequestMap
219   messageChan <- newChan
220   initRsp <- newEmptyMVar
221
222   mainThreadId <- myThreadId
223
224   let context = SessionContext serverIn absRootDir messageChan reqMap initRsp config caps
225       initState = SessionState (IdInt 0) mempty mempty 0 False Nothing
226       runSession' = runSession context initState
227
228       errorHandler = throwTo mainThreadId :: SessionException -> IO()
229       serverListenerLauncher =
230         forkIO $ catch (serverHandler serverOut context) errorHandler
231       server = (Just serverIn, Just serverOut, Nothing, serverProc)
232       serverAndListenerFinalizer tid =
233         finally (timeout (messageTimeout config * 1000000)
234                          (runSession' exitServer))
235                 (cleanupProcess server >> killThread tid)
236
237   (result, _) <- bracket serverListenerLauncher serverAndListenerFinalizer
238                          (const $ runSession' session)
239   return result
240
241 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
242 updateStateC = awaitForever $ \msg -> do
243   updateState msg
244   yield msg
245
246 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m)
247             => FromServerMessage -> m ()
248 updateState (NotPublishDiagnostics n) = do
249   let List diags = n ^. params . diagnostics
250       doc = n ^. params . uri
251   modify (\s ->
252     let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s)
253       in s { curDiagnostics = newDiags })
254
255 updateState (ReqApplyWorkspaceEdit r) = do
256
257   allChangeParams <- case r ^. params . edit . documentChanges of
258     Just (List cs) -> do
259       mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
260       return $ map getParams cs
261     Nothing -> case r ^. params . edit . changes of
262       Just cs -> do
263         mapM_ checkIfNeedsOpened (HashMap.keys cs)
264         return $ concatMap (uncurry getChangeParams) (HashMap.toList cs)
265       Nothing -> error "No changes!"
266
267   modifyM $ \s -> do
268     newVFS <- liftIO $ changeFromServerVFS (vfs s) r
269     return $ s { vfs = newVFS }
270
271   let groupedParams = groupBy (\a b -> a ^. textDocument == b ^. textDocument) allChangeParams
272       mergedParams = map mergeParams groupedParams
273
274   -- TODO: Don't do this when replaying a session
275   forM_ mergedParams (sendMessage . NotificationMessage "2.0" TextDocumentDidChange)
276
277   -- Update VFS to new document versions
278   let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
279       latestVersions = map ((^. textDocument) . last) sortedVersions
280       bumpedVersions = map (version . _Just +~ 1) latestVersions
281
282   forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
283     modify $ \s ->
284       let oldVFS = vfs s
285           update (VirtualFile oldV t mf) = VirtualFile (fromMaybe oldV v) t mf
286           newVFS = Map.adjust update (toNormalizedUri uri) oldVFS
287       in s { vfs = newVFS }
288
289   where checkIfNeedsOpened uri = do
290           oldVFS <- vfs <$> get
291           ctx <- ask
292
293           -- if its not open, open it
294           unless (toNormalizedUri uri `Map.member` oldVFS) $ do
295             let fp = fromJust $ uriToFilePath uri
296             contents <- liftIO $ T.readFile fp
297             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
298                 msg = NotificationMessage "2.0" TextDocumentDidOpen (DidOpenTextDocumentParams item)
299             liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
300
301             modifyM $ \s -> do
302               newVFS <- liftIO $ openVFS (vfs s) msg
303               return $ s { vfs = newVFS }
304
305         getParams (TextDocumentEdit docId (List edits)) =
306           let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
307             in DidChangeTextDocumentParams docId (List changeEvents)
308
309         textDocumentVersions uri = map (VersionedTextDocumentIdentifier uri . Just) [0..]
310
311         textDocumentEdits uri edits = map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip (textDocumentVersions uri) edits
312
313         getChangeParams uri (List edits) = map getParams (textDocumentEdits uri (reverse edits))
314
315         mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
316         mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
317                               in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
318 updateState _ = return ()
319
320 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
321 sendMessage msg = do
322   h <- serverIn <$> ask
323   logMsg LogClient msg
324   liftIO $ B.hPut h (addHeader $ encode msg)
325
326 -- | Execute a block f that will throw a 'Timeout' exception
327 -- after duration seconds. This will override the global timeout
328 -- for waiting for messages to arrive defined in 'SessionConfig'.
329 withTimeout :: Int -> Session a -> Session a
330 withTimeout duration f = do
331   chan <- asks messageChan
332   timeoutId <- curTimeoutId <$> get
333   modify $ \s -> s { overridingTimeout = True }
334   liftIO $ forkIO $ do
335     threadDelay (duration * 1000000)
336     writeChan chan (TimeoutMessage timeoutId)
337   res <- f
338   modify $ \s -> s { curTimeoutId = timeoutId + 1,
339                      overridingTimeout = False
340                    }
341   return res
342
343 data LogMsgType = LogServer | LogClient
344   deriving Eq
345
346 -- | Logs the message if the config specified it
347 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
348        => LogMsgType -> a -> m ()
349 logMsg t msg = do
350   shouldLog <- asks $ logMessages . config
351   shouldColor <- asks $ logColor . config
352   liftIO $ when shouldLog $ do
353     when shouldColor $ setSGR [SetColor Foreground Dull color]
354     putStrLn $ arrow ++ showPretty msg
355     when shouldColor $ setSGR [Reset]
356
357   where arrow
358           | t == LogServer  = "<-- "
359           | otherwise       = "--> "
360         color
361           | t == LogServer  = Magenta
362           | otherwise       = Cyan
363
364         showPretty = B.unpack . encodePretty
365