Upgrading to haskell-lsp 0.15
[lsp-test.git] / src / Language / Haskell / LSP / Test / Session.hs
1 {-# LANGUAGE CPP               #-}
2 {-# LANGUAGE OverloadedStrings #-}
3 {-# LANGUAGE FlexibleInstances #-}
4 {-# LANGUAGE MultiParamTypeClasses #-}
5 {-# LANGUAGE FlexibleContexts #-}
6 {-# LANGUAGE RankNTypes #-}
7
8 module Language.Haskell.LSP.Test.Session
9   ( Session
10   , SessionConfig(..)
11   , defaultConfig
12   , SessionMessage(..)
13   , SessionContext(..)
14   , SessionState(..)
15   , runSessionWithHandles
16   , get
17   , put
18   , modify
19   , modifyM
20   , ask
21   , asks
22   , sendMessage
23   , updateState
24   , withTimeout
25   , logMsg
26   , LogMsgType(..)
27   )
28
29 where
30
31 import Control.Concurrent hiding (yield)
32 import Control.Exception
33 import Control.Lens hiding (List)
34 import Control.Monad
35 import Control.Monad.IO.Class
36 import Control.Monad.Except
37 #if __GLASGOW_HASKELL__ >= 806
38 import Control.Monad.Fail
39 #endif
40 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
41 import qualified Control.Monad.Trans.Reader as Reader (ask)
42 import Control.Monad.Trans.State (StateT, runStateT)
43 import qualified Control.Monad.Trans.State as State (get, put)
44 import qualified Data.ByteString.Lazy.Char8 as B
45 import Data.Aeson
46 import Data.Aeson.Encode.Pretty
47 import Data.Conduit as Conduit
48 import Data.Conduit.Parser as Parser
49 import Data.Default
50 import Data.Foldable
51 import Data.List
52 import qualified Data.Map as Map
53 import qualified Data.Text as T
54 import qualified Data.Text.IO as T
55 import qualified Data.HashMap.Strict as HashMap
56 import Data.Maybe
57 import Data.Function
58 import Language.Haskell.LSP.Messages
59 import Language.Haskell.LSP.Types.Capabilities
60 import Language.Haskell.LSP.Types
61 import Language.Haskell.LSP.Types.Lens hiding (error)
62 import Language.Haskell.LSP.VFS
63 import Language.Haskell.LSP.Test.Decoding
64 import Language.Haskell.LSP.Test.Exceptions
65 import System.Console.ANSI
66 import System.Directory
67 import System.IO
68
69 -- | A session representing one instance of launching and connecting to a server.
70 --
71 -- You can send and receive messages to the server within 'Session' via
72 -- 'Language.Haskell.LSP.Test.message',
73 -- 'Language.Haskell.LSP.Test.sendRequest' and
74 -- 'Language.Haskell.LSP.Test.sendNotification'.
75
76 type Session = ParserStateReader FromServerMessage SessionState SessionContext IO
77
78 #if __GLASGOW_HASKELL__ >= 806
79 instance MonadFail Session where
80   fail s = do
81     lastMsg <- fromJust . lastReceivedMessage <$> get
82     liftIO $ throw (UnexpectedMessage s lastMsg)
83 #endif
84
85 -- | Stuff you can configure for a 'Session'.
86 data SessionConfig = SessionConfig
87   { messageTimeout :: Int  -- ^ Maximum time to wait for a message in seconds, defaults to 60.
88   , logStdErr      :: Bool -- ^ Redirect the server's stderr to this stdout, defaults to False.
89   , logMessages    :: Bool -- ^ Trace the messages sent and received to stdout, defaults to False.
90   , logColor       :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
91   , lspConfig      :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
92   }
93
94 -- | The configuration used in 'Language.Haskell.LSP.Test.runSession'.
95 defaultConfig :: SessionConfig
96 defaultConfig = SessionConfig 60 False False True Nothing
97
98 instance Default SessionConfig where
99   def = defaultConfig
100
101 data SessionMessage = ServerMessage FromServerMessage
102                     | TimeoutMessage Int
103   deriving Show
104
105 data SessionContext = SessionContext
106   {
107     serverIn :: Handle
108   , rootDir :: FilePath
109   , messageChan :: Chan SessionMessage
110   , requestMap :: MVar RequestMap
111   , initRsp :: MVar InitializeResponse
112   , config :: SessionConfig
113   , sessionCapabilities :: ClientCapabilities
114   }
115
116 class Monad m => HasReader r m where
117   ask :: m r
118   asks :: (r -> b) -> m b
119   asks f = f <$> ask
120
121 instance Monad m => HasReader r (ParserStateReader a s r m) where
122   ask = lift $ lift Reader.ask
123
124 instance Monad m => HasReader SessionContext (ConduitM a b (StateT s (ReaderT SessionContext m))) where
125   ask = lift $ lift Reader.ask
126
127 data SessionState = SessionState
128   {
129     curReqId :: LspId
130   , vfs :: VFS
131   , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
132   , curTimeoutId :: Int
133   , overridingTimeout :: Bool
134   -- ^ The last received message from the server.
135   -- Used for providing exception information
136   , lastReceivedMessage :: Maybe FromServerMessage
137   }
138
139 class Monad m => HasState s m where
140   get :: m s
141
142   put :: s -> m ()
143
144   modify :: (s -> s) -> m ()
145   modify f = get >>= put . f
146
147   modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
148   modifyM f = get >>= f >>= put
149
150 instance Monad m => HasState s (ParserStateReader a s r m) where
151   get = lift State.get
152   put = lift . State.put
153
154 instance Monad m => HasState SessionState (ConduitM a b (StateT SessionState m))
155  where
156   get = lift State.get
157   put = lift . State.put
158
159 type ParserStateReader a s r m = ConduitParser a (StateT s (ReaderT r m))
160
161 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
162 runSession context state session = runReaderT (runStateT conduit state) context
163   where
164     conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
165
166     handler (Unexpected "ConduitParser.empty") = do
167       lastMsg <- fromJust . lastReceivedMessage <$> get
168       name <- getParserName
169       liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
170
171     handler e = throw e
172
173     chanSource = do
174       msg <- liftIO $ readChan (messageChan context)
175       yield msg
176       chanSource
177
178     watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
179     watchdog = Conduit.awaitForever $ \msg -> do
180       curId <- curTimeoutId <$> get
181       case msg of
182         ServerMessage sMsg -> yield sMsg
183         TimeoutMessage tId -> when (curId == tId) $ throw Timeout
184
185 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
186 -- It also does not automatically send initialize and exit messages.
187 runSessionWithHandles :: Handle -- ^ Server in
188                       -> Handle -- ^ Server out
189                       -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
190                       -> SessionConfig
191                       -> ClientCapabilities
192                       -> FilePath -- ^ Root directory
193                       -> Session a
194                       -> IO a
195 runSessionWithHandles serverIn serverOut serverHandler config caps rootDir session = do
196   absRootDir <- canonicalizePath rootDir
197
198   hSetBuffering serverIn  NoBuffering
199   hSetBuffering serverOut NoBuffering
200   -- This is required to make sure that we don’t get any
201   -- newline conversion or weird encoding issues.
202   hSetBinaryMode serverIn True
203   hSetBinaryMode serverOut True
204
205   reqMap <- newMVar newRequestMap
206   messageChan <- newChan
207   initRsp <- newEmptyMVar
208
209   mainThreadId <- myThreadId
210
211   let context = SessionContext serverIn absRootDir messageChan reqMap initRsp config caps
212       initState = SessionState (IdInt 0) mempty mempty 0 False Nothing
213       launchServerHandler = forkIO $ catch (serverHandler serverOut context)
214                                            (throwTo mainThreadId :: SessionException -> IO ())
215   (result, _) <- bracket launchServerHandler killThread $
216     const $ runSession context initState session
217
218   return result
219
220 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
221 updateStateC = awaitForever $ \msg -> do
222   updateState msg
223   yield msg
224
225 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m) => FromServerMessage -> m ()
226 updateState (NotPublishDiagnostics n) = do
227   let List diags = n ^. params . diagnostics
228       doc = n ^. params . uri
229   modify (\s ->
230     let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s)
231       in s { curDiagnostics = newDiags })
232
233 updateState (ReqApplyWorkspaceEdit r) = do
234
235   allChangeParams <- case r ^. params . edit . documentChanges of
236     Just (List cs) -> do
237       mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
238       return $ map getParams cs
239     Nothing -> case r ^. params . edit . changes of
240       Just cs -> do
241         mapM_ checkIfNeedsOpened (HashMap.keys cs)
242         return $ concatMap (uncurry getChangeParams) (HashMap.toList cs)
243       Nothing -> error "No changes!"
244
245   modifyM $ \s -> do
246     newVFS <- liftIO $ changeFromServerVFS (vfs s) r
247     return $ s { vfs = newVFS }
248
249   let groupedParams = groupBy (\a b -> a ^. textDocument == b ^. textDocument) allChangeParams
250       mergedParams = map mergeParams groupedParams
251
252   -- TODO: Don't do this when replaying a session
253   forM_ mergedParams (sendMessage . NotificationMessage "2.0" TextDocumentDidChange)
254
255   -- Update VFS to new document versions
256   let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
257       latestVersions = map ((^. textDocument) . last) sortedVersions
258       bumpedVersions = map (version . _Just +~ 1) latestVersions
259
260   forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
261     modify $ \s ->
262       let oldVFS = vfs s
263           update (VirtualFile oldV t mf) = VirtualFile (fromMaybe oldV v) t mf
264           newVFS = Map.adjust update (toNormalizedUri uri) oldVFS
265       in s { vfs = newVFS }
266
267   where checkIfNeedsOpened uri = do
268           oldVFS <- vfs <$> get
269           ctx <- ask
270
271           -- if its not open, open it
272           unless (toNormalizedUri uri `Map.member` oldVFS) $ do
273             let fp = fromJust $ uriToFilePath uri
274             contents <- liftIO $ T.readFile fp
275             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
276                 msg = NotificationMessage "2.0" TextDocumentDidOpen (DidOpenTextDocumentParams item)
277             liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
278
279             modifyM $ \s -> do
280               newVFS <- liftIO $ openVFS (vfs s) msg
281               return $ s { vfs = newVFS }
282
283         getParams (TextDocumentEdit docId (List edits)) =
284           let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
285             in DidChangeTextDocumentParams docId (List changeEvents)
286
287         textDocumentVersions uri = map (VersionedTextDocumentIdentifier uri . Just) [0..]
288
289         textDocumentEdits uri edits = map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip (textDocumentVersions uri) edits
290
291         getChangeParams uri (List edits) = map getParams (textDocumentEdits uri (reverse edits))
292
293         mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
294         mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
295                               in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
296 updateState _ = return ()
297
298 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
299 sendMessage msg = do
300   h <- serverIn <$> ask
301   logMsg LogClient msg
302   liftIO $ B.hPut h (addHeader $ encode msg)
303
304 -- | Execute a block f that will throw a 'Timeout' exception
305 -- after duration seconds. This will override the global timeout
306 -- for waiting for messages to arrive defined in 'SessionConfig'.
307 withTimeout :: Int -> Session a -> Session a
308 withTimeout duration f = do
309   chan <- asks messageChan
310   timeoutId <- curTimeoutId <$> get
311   modify $ \s -> s { overridingTimeout = True }
312   liftIO $ forkIO $ do
313     threadDelay (duration * 1000000)
314     writeChan chan (TimeoutMessage timeoutId)
315   res <- f
316   modify $ \s -> s { curTimeoutId = timeoutId + 1,
317                      overridingTimeout = False
318                    }
319   return res
320
321 data LogMsgType = LogServer | LogClient
322   deriving Eq
323
324 -- | Logs the message if the config specified it
325 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
326        => LogMsgType -> a -> m ()
327 logMsg t msg = do
328   shouldLog <- asks $ logMessages . config
329   shouldColor <- asks $ logColor . config
330   liftIO $ when shouldLog $ do
331     when shouldColor $ setSGR [SetColor Foreground Dull color]
332     putStrLn $ arrow ++ showPretty msg
333     when shouldColor $ setSGR [Reset]
334
335   where arrow
336           | t == LogServer  = "<-- "
337           | otherwise       = "--> "
338         color
339           | t == LogServer  = Magenta
340           | otherwise       = Cyan
341
342         showPretty = B.unpack . encodePretty
343