bbfdf386ac167bd0f8ab5b9a277b754920aacf57
[lsp-test.git] / src / Language / Haskell / LSP / Test / Session.hs
1 {-# LANGUAGE CPP               #-}
2 {-# LANGUAGE OverloadedStrings #-}
3 {-# LANGUAGE FlexibleInstances #-}
4 {-# LANGUAGE MultiParamTypeClasses #-}
5 {-# LANGUAGE FlexibleContexts #-}
6 {-# LANGUAGE RankNTypes #-}
7
8 module Language.Haskell.LSP.Test.Session
9   ( Session
10   , SessionConfig(..)
11   , defaultConfig
12   , SessionMessage(..)
13   , SessionContext(..)
14   , SessionState(..)
15   , runSessionWithHandles
16   , get
17   , put
18   , modify
19   , modifyM
20   , ask
21   , asks
22   , sendMessage
23   , updateState
24   , withTimeout
25   , logMsg
26   , LogMsgType(..)
27   )
28
29 where
30
31 import Control.Concurrent hiding (yield)
32 import Control.Exception
33 import Control.Lens hiding (List)
34 import Control.Monad
35 import Control.Monad.IO.Class
36 import Control.Monad.Except
37 #if __GLASGOW_HASKELL__ >= 806
38 import Control.Monad.Fail
39 #endif
40 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
41 import qualified Control.Monad.Trans.Reader as Reader (ask)
42 import Control.Monad.Trans.State (StateT, runStateT)
43 import qualified Control.Monad.Trans.State as State (get, put)
44 import qualified Data.ByteString.Lazy.Char8 as B
45 import Data.Aeson
46 import Data.Aeson.Encode.Pretty
47 import Data.Conduit as Conduit
48 import Data.Conduit.Parser as Parser
49 import Data.Default
50 import Data.Foldable
51 import Data.List
52 import qualified Data.Map as Map
53 import qualified Data.Text as T
54 import qualified Data.Text.IO as T
55 import qualified Data.HashMap.Strict as HashMap
56 import Data.Maybe
57 import Data.Function
58 import Language.Haskell.LSP.Messages
59 import Language.Haskell.LSP.Types.Capabilities
60 import Language.Haskell.LSP.Types
61 import Language.Haskell.LSP.Types.Lens hiding (error)
62 import Language.Haskell.LSP.VFS
63 import Language.Haskell.LSP.Test.Compat
64 import Language.Haskell.LSP.Test.Decoding
65 import Language.Haskell.LSP.Test.Exceptions
66 import System.Console.ANSI
67 import System.Directory
68 import System.IO
69 import System.Process (ProcessHandle())
70 import System.Timeout
71
72 -- | A session representing one instance of launching and connecting to a server.
73 --
74 -- You can send and receive messages to the server within 'Session' via
75 -- 'Language.Haskell.LSP.Test.message',
76 -- 'Language.Haskell.LSP.Test.sendRequest' and
77 -- 'Language.Haskell.LSP.Test.sendNotification'.
78
79 type Session = ParserStateReader FromServerMessage SessionState SessionContext IO
80
81 #if __GLASGOW_HASKELL__ >= 806
82 instance MonadFail Session where
83   fail s = do
84     lastMsg <- fromJust . lastReceivedMessage <$> get
85     liftIO $ throw (UnexpectedMessage s lastMsg)
86 #endif
87
88 -- | Stuff you can configure for a 'Session'.
89 data SessionConfig = SessionConfig
90   { messageTimeout :: Int  -- ^ Maximum time to wait for a message in seconds, defaults to 60.
91   , logStdErr      :: Bool -- ^ Redirect the server's stderr to this stdout, defaults to False.
92   , logMessages    :: Bool -- ^ Trace the messages sent and received to stdout, defaults to False.
93   , logColor       :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
94   , lspConfig      :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
95   }
96
97 -- | The configuration used in 'Language.Haskell.LSP.Test.runSession'.
98 defaultConfig :: SessionConfig
99 defaultConfig = SessionConfig 60 False False True Nothing
100
101 instance Default SessionConfig where
102   def = defaultConfig
103
104 data SessionMessage = ServerMessage FromServerMessage
105                     | TimeoutMessage Int
106   deriving Show
107
108 data SessionContext = SessionContext
109   {
110     serverIn :: Handle
111   , rootDir :: FilePath
112   , messageChan :: Chan SessionMessage
113   , requestMap :: MVar RequestMap
114   , initRsp :: MVar InitializeResponse
115   , config :: SessionConfig
116   , sessionCapabilities :: ClientCapabilities
117   }
118
119 class Monad m => HasReader r m where
120   ask :: m r
121   asks :: (r -> b) -> m b
122   asks f = f <$> ask
123
124 instance Monad m => HasReader r (ParserStateReader a s r m) where
125   ask = lift $ lift Reader.ask
126
127 instance Monad m => HasReader SessionContext (ConduitM a b (StateT s (ReaderT SessionContext m))) where
128   ask = lift $ lift Reader.ask
129
130 data SessionState = SessionState
131   {
132     curReqId :: LspId
133   , vfs :: VFS
134   , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
135   , curTimeoutId :: Int
136   , overridingTimeout :: Bool
137   -- ^ The last received message from the server.
138   -- Used for providing exception information
139   , lastReceivedMessage :: Maybe FromServerMessage
140   }
141
142 class Monad m => HasState s m where
143   get :: m s
144
145   put :: s -> m ()
146
147   modify :: (s -> s) -> m ()
148   modify f = get >>= put . f
149
150   modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
151   modifyM f = get >>= f >>= put
152
153 instance Monad m => HasState s (ParserStateReader a s r m) where
154   get = lift State.get
155   put = lift . State.put
156
157 instance Monad m => HasState SessionState (ConduitM a b (StateT SessionState m))
158  where
159   get = lift State.get
160   put = lift . State.put
161
162 type ParserStateReader a s r m = ConduitParser a (StateT s (ReaderT r m))
163
164 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
165 runSession context state session = runReaderT (runStateT conduit state) context
166   where
167     conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
168
169     handler (Unexpected "ConduitParser.empty") = do
170       lastMsg <- fromJust . lastReceivedMessage <$> get
171       name <- getParserName
172       liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
173
174     handler e = throw e
175
176     chanSource = do
177       msg <- liftIO $ readChan (messageChan context)
178       yield msg
179       chanSource
180
181     watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
182     watchdog = Conduit.awaitForever $ \msg -> do
183       curId <- curTimeoutId <$> get
184       case msg of
185         ServerMessage sMsg -> yield sMsg
186         TimeoutMessage tId -> when (curId == tId) $ throw Timeout
187
188 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
189 -- It also does not automatically send initialize and exit messages.
190 runSessionWithHandles :: Handle -- ^ Server in
191                       -> Handle -- ^ Server out
192                       -> ProcessHandle -- ^ Server process
193                       -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
194                       -> SessionConfig
195                       -> ClientCapabilities
196                       -> FilePath -- ^ Root directory
197                       -> Session () -- ^ To exit the Server properly
198                       -> Session a
199                       -> IO a
200 runSessionWithHandles serverIn serverOut serverProc serverHandler config caps rootDir exitServer session = do
201   
202   absRootDir <- canonicalizePath rootDir
203
204   hSetBuffering serverIn  NoBuffering
205   hSetBuffering serverOut NoBuffering
206   -- This is required to make sure that we don’t get any
207   -- newline conversion or weird encoding issues.
208   hSetBinaryMode serverIn True
209   hSetBinaryMode serverOut True
210
211   reqMap <- newMVar newRequestMap
212   messageChan <- newChan
213   initRsp <- newEmptyMVar
214
215   mainThreadId <- myThreadId
216
217   let context = SessionContext serverIn absRootDir messageChan reqMap initRsp config caps
218       initState = SessionState (IdInt 0) mempty mempty 0 False Nothing
219       runSession' = runSession context initState
220       
221       errorHandler = throwTo mainThreadId :: SessionException -> IO()
222       serverLauncher = forkIO $ catch (serverHandler serverOut context) errorHandler
223       server = (Just serverIn, Just serverOut, Nothing, serverProc)
224       serverFinalizer tid = finally (timeout (messageTimeout config * 1000000)
225                                              (runSession' exitServer))
226                                     (cleanupRunningProcess server >> killThread tid)
227       
228   (result, _) <- bracket serverLauncher serverFinalizer (const $ runSession' session)
229   return result
230
231 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
232 updateStateC = awaitForever $ \msg -> do
233   updateState msg
234   yield msg
235
236 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m) => FromServerMessage -> m ()
237 updateState (NotPublishDiagnostics n) = do
238   let List diags = n ^. params . diagnostics
239       doc = n ^. params . uri
240   modify (\s ->
241     let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s)
242       in s { curDiagnostics = newDiags })
243
244 updateState (ReqApplyWorkspaceEdit r) = do
245
246   allChangeParams <- case r ^. params . edit . documentChanges of
247     Just (List cs) -> do
248       mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
249       return $ map getParams cs
250     Nothing -> case r ^. params . edit . changes of
251       Just cs -> do
252         mapM_ checkIfNeedsOpened (HashMap.keys cs)
253         return $ concatMap (uncurry getChangeParams) (HashMap.toList cs)
254       Nothing -> error "No changes!"
255
256   modifyM $ \s -> do
257     newVFS <- liftIO $ changeFromServerVFS (vfs s) r
258     return $ s { vfs = newVFS }
259
260   let groupedParams = groupBy (\a b -> a ^. textDocument == b ^. textDocument) allChangeParams
261       mergedParams = map mergeParams groupedParams
262
263   -- TODO: Don't do this when replaying a session
264   forM_ mergedParams (sendMessage . NotificationMessage "2.0" TextDocumentDidChange)
265
266   -- Update VFS to new document versions
267   let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
268       latestVersions = map ((^. textDocument) . last) sortedVersions
269       bumpedVersions = map (version . _Just +~ 1) latestVersions
270
271   forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
272     modify $ \s ->
273       let oldVFS = vfs s
274           update (VirtualFile oldV t mf) = VirtualFile (fromMaybe oldV v) t mf
275           newVFS = Map.adjust update (toNormalizedUri uri) oldVFS
276       in s { vfs = newVFS }
277
278   where checkIfNeedsOpened uri = do
279           oldVFS <- vfs <$> get
280           ctx <- ask
281
282           -- if its not open, open it
283           unless (toNormalizedUri uri `Map.member` oldVFS) $ do
284             let fp = fromJust $ uriToFilePath uri
285             contents <- liftIO $ T.readFile fp
286             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
287                 msg = NotificationMessage "2.0" TextDocumentDidOpen (DidOpenTextDocumentParams item)
288             liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
289
290             modifyM $ \s -> do
291               newVFS <- liftIO $ openVFS (vfs s) msg
292               return $ s { vfs = newVFS }
293
294         getParams (TextDocumentEdit docId (List edits)) =
295           let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
296             in DidChangeTextDocumentParams docId (List changeEvents)
297
298         textDocumentVersions uri = map (VersionedTextDocumentIdentifier uri . Just) [0..]
299
300         textDocumentEdits uri edits = map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip (textDocumentVersions uri) edits
301
302         getChangeParams uri (List edits) = map getParams (textDocumentEdits uri (reverse edits))
303
304         mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
305         mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
306                               in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
307 updateState _ = return ()
308
309 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
310 sendMessage msg = do
311   h <- serverIn <$> ask
312   logMsg LogClient msg
313   liftIO $ B.hPut h (addHeader $ encode msg)
314
315 -- | Execute a block f that will throw a 'Timeout' exception
316 -- after duration seconds. This will override the global timeout
317 -- for waiting for messages to arrive defined in 'SessionConfig'.
318 withTimeout :: Int -> Session a -> Session a
319 withTimeout duration f = do
320   chan <- asks messageChan
321   timeoutId <- curTimeoutId <$> get
322   modify $ \s -> s { overridingTimeout = True }
323   liftIO $ forkIO $ do
324     threadDelay (duration * 1000000)
325     writeChan chan (TimeoutMessage timeoutId)
326   res <- f
327   modify $ \s -> s { curTimeoutId = timeoutId + 1,
328                      overridingTimeout = False
329                    }
330   return res
331
332 data LogMsgType = LogServer | LogClient
333   deriving Eq
334
335 -- | Logs the message if the config specified it
336 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
337        => LogMsgType -> a -> m ()
338 logMsg t msg = do
339   shouldLog <- asks $ logMessages . config
340   shouldColor <- asks $ logColor . config
341   liftIO $ when shouldLog $ do
342     when shouldColor $ setSGR [SetColor Foreground Dull color]
343     putStrLn $ arrow ++ showPretty msg
344     when shouldColor $ setSGR [Reset]
345
346   where arrow
347           | t == LogServer  = "<-- "
348           | otherwise       = "--> "
349         color
350           | t == LogServer  = Magenta
351           | otherwise       = Cyan
352
353         showPretty = B.unpack . encodePretty
354