Add MonadFail constraint for GHC 8.6.1
[lsp-test.git] / src / Language / Haskell / LSP / Test / Session.hs
1 {-# LANGUAGE CPP               #-}
2 {-# LANGUAGE OverloadedStrings #-}
3 {-# LANGUAGE FlexibleInstances #-}
4 {-# LANGUAGE MultiParamTypeClasses #-}
5 {-# LANGUAGE FlexibleContexts #-}
6
7 module Language.Haskell.LSP.Test.Session
8   ( Session
9   , SessionConfig(..)
10   , defaultConfig
11   , SessionMessage(..)
12   , SessionContext(..)
13   , SessionState(..)
14   , runSessionWithHandles
15   , get
16   , put
17   , modify
18   , modifyM
19   , ask
20   , asks
21   , sendMessage
22   , updateState
23   , withTimeout
24   , logMsg
25   , LogMsgType(..)
26   )
27
28 where
29
30 import Control.Concurrent hiding (yield)
31 import Control.Exception
32 import Control.Lens hiding (List)
33 import Control.Monad
34 import Control.Monad.IO.Class
35 import Control.Monad.Except
36 #if __GLASGOW_HASKELL__ >= 806
37 import qualified Control.Monad.Fail as Fail
38 #endif
39 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
40 import qualified Control.Monad.Trans.Reader as Reader (ask)
41 import Control.Monad.Trans.State (StateT, runStateT)
42 import qualified Control.Monad.Trans.State as State (get, put)
43 import qualified Data.ByteString.Lazy.Char8 as B
44 import Data.Aeson
45 import Data.Aeson.Encode.Pretty
46 import Data.Conduit as Conduit
47 import Data.Conduit.Parser as Parser
48 import Data.Default
49 import Data.Foldable
50 import Data.List
51 import qualified Data.Map as Map
52 import qualified Data.Text as T
53 import qualified Data.Text.IO as T
54 import qualified Data.HashMap.Strict as HashMap
55 import Data.Maybe
56 import Data.Function
57 import Language.Haskell.LSP.Messages
58 import Language.Haskell.LSP.Types.Capabilities
59 import Language.Haskell.LSP.Types
60 import Language.Haskell.LSP.Types.Lens hiding (error)
61 import Language.Haskell.LSP.VFS
62 import Language.Haskell.LSP.Test.Decoding
63 import Language.Haskell.LSP.Test.Exceptions
64 import System.Console.ANSI
65 import System.Directory
66 import System.IO
67
68 -- | A session representing one instance of launching and connecting to a server.
69 -- 
70 -- You can send and receive messages to the server within 'Session' via 'getMessage',
71 -- 'sendRequest' and 'sendNotification'.
72 --
73
74 type Session = ParserStateReader FromServerMessage SessionState SessionContext IO
75
76 -- | Stuff you can configure for a 'Session'.
77 data SessionConfig = SessionConfig
78   { messageTimeout :: Int  -- ^ Maximum time to wait for a message in seconds, defaults to 60.
79   , logStdErr      :: Bool -- ^ Redirect the server's stderr to this stdout, defaults to False.
80   , logMessages    :: Bool -- ^ Trace the messages sent and received to stdout, defaults to False.
81   , logColor       :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
82   }
83
84 -- | The configuration used in 'Language.Haskell.LSP.Test.runSession'.
85 defaultConfig :: SessionConfig
86 defaultConfig = SessionConfig 60 False False True
87
88 instance Default SessionConfig where
89   def = defaultConfig
90
91 data SessionMessage = ServerMessage FromServerMessage
92                     | TimeoutMessage Int
93   deriving Show
94
95 data SessionContext = SessionContext
96   {
97     serverIn :: Handle
98   , rootDir :: FilePath
99   , messageChan :: Chan SessionMessage
100   , requestMap :: MVar RequestMap
101   , initRsp :: MVar InitializeResponse
102   , config :: SessionConfig
103   , sessionCapabilities :: ClientCapabilities
104   }
105
106 class Monad m => HasReader r m where
107   ask :: m r
108   asks :: (r -> b) -> m b
109   asks f = f <$> ask
110
111 instance Monad m => HasReader r (ParserStateReader a s r m) where
112   ask = lift $ lift Reader.ask
113
114 instance Monad m => HasReader SessionContext (ConduitM a b (StateT s (ReaderT SessionContext m))) where
115   ask = lift $ lift Reader.ask
116
117 data SessionState = SessionState
118   {
119     curReqId :: LspId
120   , vfs :: VFS
121   , curDiagnostics :: Map.Map Uri [Diagnostic]
122   , curTimeoutId :: Int
123   , overridingTimeout :: Bool
124   -- ^ The last received message from the server.
125   -- Used for providing exception information
126   , lastReceivedMessage :: Maybe FromServerMessage
127   }
128
129 class Monad m => HasState s m where
130   get :: m s
131
132   put :: s -> m ()
133
134   modify :: (s -> s) -> m ()
135   modify f = get >>= put . f
136
137   modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
138   modifyM f = get >>= f >>= put
139
140 instance Monad m => HasState s (ParserStateReader a s r m) where
141   get = lift State.get
142   put = lift . State.put
143
144 instance Monad m => HasState SessionState (ConduitM a b (StateT SessionState m))
145  where
146   get = lift State.get
147   put = lift . State.put
148
149 type ParserStateReader a s r m = ConduitParser a (StateT s (ReaderT r m))
150
151 #if __GLASGOW_HASKELL__ >= 806
152 instance (Fail.MonadFail m) => Fail.MonadFail (ParserStateReader a s r m) where
153   fail = Fail.fail
154 #endif
155
156 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
157 runSession context state session = runReaderT (runStateT conduit state) context
158   where
159     conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
160         
161     handler (Unexpected "ConduitParser.empty") = do
162       lastMsg <- fromJust . lastReceivedMessage <$> get
163       name <- getParserName
164       liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
165
166     handler e = throw e
167
168     chanSource = do
169       msg <- liftIO $ readChan (messageChan context)
170       yield msg
171       chanSource
172
173     watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
174     watchdog = Conduit.awaitForever $ \msg -> do
175       curId <- curTimeoutId <$> get
176       case msg of
177         ServerMessage sMsg -> yield sMsg
178         TimeoutMessage tId -> when (curId == tId) $ throw Timeout
179
180 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
181 -- It also does not automatically send initialize and exit messages.
182 runSessionWithHandles :: Handle -- ^ Server in
183                       -> Handle -- ^ Server out
184                       -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
185                       -> SessionConfig
186                       -> ClientCapabilities
187                       -> FilePath -- ^ Root directory
188                       -> Session a
189                       -> IO a
190 runSessionWithHandles serverIn serverOut serverHandler config caps rootDir session = do
191   absRootDir <- canonicalizePath rootDir
192
193   hSetBuffering serverIn  NoBuffering
194   hSetBuffering serverOut NoBuffering
195
196   reqMap <- newMVar newRequestMap
197   messageChan <- newChan
198   initRsp <- newEmptyMVar
199
200   let context = SessionContext serverIn absRootDir messageChan reqMap initRsp config caps
201       initState = SessionState (IdInt 0) mempty mempty 0 False Nothing
202
203   threadId <- forkIO $ void $ serverHandler serverOut context
204   (result, _) <- runSession context initState session
205
206   killThread threadId
207
208   return result
209
210 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
211 updateStateC = awaitForever $ \msg -> do
212   updateState msg
213   yield msg
214
215 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m) => FromServerMessage -> m ()
216 updateState (NotPublishDiagnostics n) = do
217   let List diags = n ^. params . diagnostics
218       doc = n ^. params . uri
219   modify (\s ->
220     let newDiags = Map.insert doc diags (curDiagnostics s)
221       in s { curDiagnostics = newDiags })
222
223 updateState (ReqApplyWorkspaceEdit r) = do
224
225   allChangeParams <- case r ^. params . edit . documentChanges of
226     Just (List cs) -> do
227       mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
228       return $ map getParams cs
229     Nothing -> case r ^. params . edit . changes of
230       Just cs -> do
231         mapM_ checkIfNeedsOpened (HashMap.keys cs)
232         return $ concatMap (uncurry getChangeParams) (HashMap.toList cs)
233       Nothing -> error "No changes!"
234
235   modifyM $ \s -> do
236     newVFS <- liftIO $ changeFromServerVFS (vfs s) r
237     return $ s { vfs = newVFS }
238
239   let groupedParams = groupBy (\a b -> (a ^. textDocument == b ^. textDocument)) allChangeParams
240       mergedParams = map mergeParams groupedParams
241
242   -- TODO: Don't do this when replaying a session
243   forM_ mergedParams (sendMessage . NotificationMessage "2.0" TextDocumentDidChange)
244
245   -- Update VFS to new document versions
246   let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
247       latestVersions = map ((^. textDocument) . last) sortedVersions
248       bumpedVersions = map (version . _Just +~ 1) latestVersions
249
250   forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
251     modify $ \s ->
252       let oldVFS = vfs s
253           update (VirtualFile oldV t) = VirtualFile (fromMaybe oldV v) t
254           newVFS = Map.adjust update uri oldVFS
255       in s { vfs = newVFS }
256
257   where checkIfNeedsOpened uri = do
258           oldVFS <- vfs <$> get
259           ctx <- ask
260
261           -- if its not open, open it
262           unless (uri `Map.member` oldVFS) $ do
263             let fp = fromJust $ uriToFilePath uri
264             contents <- liftIO $ T.readFile fp
265             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
266                 msg = NotificationMessage "2.0" TextDocumentDidOpen (DidOpenTextDocumentParams item)
267             liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
268
269             modifyM $ \s -> do 
270               newVFS <- liftIO $ openVFS (vfs s) msg
271               return $ s { vfs = newVFS }
272
273         getParams (TextDocumentEdit docId (List edits)) =
274           let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
275             in DidChangeTextDocumentParams docId (List changeEvents)
276
277         textDocumentVersions uri = map (VersionedTextDocumentIdentifier uri . Just) [0..]
278
279         textDocumentEdits uri edits = map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip (textDocumentVersions uri) edits
280
281         getChangeParams uri (List edits) = map getParams (textDocumentEdits uri (reverse edits))
282
283         mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
284         mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
285                               in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
286 updateState _ = return ()
287
288 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
289 sendMessage msg = do
290   h <- serverIn <$> ask
291   logMsg LogClient msg
292   liftIO $ B.hPut h (addHeader $ encode msg)
293
294 -- | Execute a block f that will throw a 'TimeoutException'
295 -- after duration seconds. This will override the global timeout
296 -- for waiting for messages to arrive defined in 'SessionConfig'.
297 withTimeout :: Int -> Session a -> Session a
298 withTimeout duration f = do
299   chan <- asks messageChan
300   timeoutId <- curTimeoutId <$> get 
301   modify $ \s -> s { overridingTimeout = True }
302   liftIO $ forkIO $ do
303     threadDelay (duration * 1000000)
304     writeChan chan (TimeoutMessage timeoutId)
305   res <- f
306   modify $ \s -> s { curTimeoutId = timeoutId + 1,
307                      overridingTimeout = False 
308                    }
309   return res
310
311 data LogMsgType = LogServer | LogClient
312   deriving Eq
313
314 -- | Logs the message if the config specified it
315 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
316        => LogMsgType -> a -> m ()
317 logMsg t msg = do
318   shouldLog <- asks $ logMessages . config
319   shouldColor <- asks $ logColor . config
320   liftIO $ when shouldLog $ do
321     when shouldColor $ setSGR [SetColor Foreground Dull color]
322     putStrLn $ arrow ++ showPretty msg
323     when shouldColor $ setSGR [Reset]
324
325   where arrow
326           | t == LogServer  = "<-- "
327           | otherwise       = "--> "
328         color
329           | t == LogServer  = Magenta
330           | otherwise       = Cyan
331   
332         showPretty = B.unpack . encodePretty