Add lspConfig option
[lsp-test.git] / src / Language / Haskell / LSP / Test / Session.hs
1 {-# LANGUAGE OverloadedStrings #-}
2 {-# LANGUAGE FlexibleInstances #-}
3 {-# LANGUAGE MultiParamTypeClasses #-}
4 {-# LANGUAGE FlexibleContexts #-}
5 {-# LANGUAGE RankNTypes #-}
6
7 module Language.Haskell.LSP.Test.Session
8   ( Session
9   , SessionConfig(..)
10   , defaultConfig
11   , SessionMessage(..)
12   , SessionContext(..)
13   , SessionState(..)
14   , runSessionWithHandles
15   , get
16   , put
17   , modify
18   , modifyM
19   , ask
20   , asks
21   , sendMessage
22   , updateState
23   , withTimeout
24   , logMsg
25   , LogMsgType(..)
26   )
27
28 where
29
30 import Control.Concurrent hiding (yield)
31 import Control.Exception
32 import Control.Lens hiding (List)
33 import Control.Monad
34 import Control.Monad.Fail
35 import Control.Monad.IO.Class
36 import Control.Monad.Except
37 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
38 import qualified Control.Monad.Trans.Reader as Reader (ask)
39 import Control.Monad.Trans.State (StateT, runStateT)
40 import qualified Control.Monad.Trans.State as State (get, put)
41 import qualified Data.ByteString.Lazy.Char8 as B
42 import Data.Aeson
43 import Data.Aeson.Encode.Pretty
44 import Data.Conduit as Conduit
45 import Data.Conduit.Parser as Parser
46 import Data.Default
47 import Data.Foldable
48 import Data.List
49 import qualified Data.Map as Map
50 import qualified Data.Text as T
51 import qualified Data.Text.IO as T
52 import qualified Data.HashMap.Strict as HashMap
53 import Data.Maybe
54 import Data.Function
55 import Language.Haskell.LSP.Messages
56 import Language.Haskell.LSP.Types.Capabilities
57 import Language.Haskell.LSP.Types
58 import Language.Haskell.LSP.Types.Lens hiding (error)
59 import Language.Haskell.LSP.VFS
60 import Language.Haskell.LSP.Test.Decoding
61 import Language.Haskell.LSP.Test.Exceptions
62 import System.Console.ANSI
63 import System.Directory
64 import System.IO
65
66 -- | A session representing one instance of launching and connecting to a server.
67 -- 
68 -- You can send and receive messages to the server within 'Session' via 'getMessage',
69 -- 'sendRequest' and 'sendNotification'.
70 --
71
72 type Session = ParserStateReader FromServerMessage SessionState SessionContext IO
73
74 instance MonadFail Session where
75   fail s = do
76     lastMsg <- fromJust . lastReceivedMessage <$> get
77     liftIO $ throw (UnexpectedMessage s lastMsg)
78
79 -- | Stuff you can configure for a 'Session'.
80 data SessionConfig = SessionConfig
81   { messageTimeout :: Int  -- ^ Maximum time to wait for a message in seconds, defaults to 60.
82   , logStdErr      :: Bool -- ^ Redirect the server's stderr to this stdout, defaults to False.
83   , logMessages    :: Bool -- ^ Trace the messages sent and received to stdout, defaults to True.
84   , logColor       :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
85   , lspConfig      :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
86   }
87
88 -- | The configuration used in 'Language.Haskell.LSP.Test.runSession'.
89 defaultConfig :: SessionConfig
90 defaultConfig = SessionConfig 60 False True True Nothing
91
92 instance Default SessionConfig where
93   def = defaultConfig
94
95 data SessionMessage = ServerMessage FromServerMessage
96                     | TimeoutMessage Int
97   deriving Show
98
99 data SessionContext = SessionContext
100   {
101     serverIn :: Handle
102   , rootDir :: FilePath
103   , messageChan :: Chan SessionMessage
104   , requestMap :: MVar RequestMap
105   , initRsp :: MVar InitializeResponse
106   , config :: SessionConfig
107   , sessionCapabilities :: ClientCapabilities
108   }
109
110 class Monad m => HasReader r m where
111   ask :: m r
112   asks :: (r -> b) -> m b
113   asks f = f <$> ask
114
115 instance Monad m => HasReader r (ParserStateReader a s r m) where
116   ask = lift $ lift Reader.ask
117
118 instance Monad m => HasReader SessionContext (ConduitM a b (StateT s (ReaderT SessionContext m))) where
119   ask = lift $ lift Reader.ask
120
121 data SessionState = SessionState
122   {
123     curReqId :: LspId
124   , vfs :: VFS
125   , curDiagnostics :: Map.Map Uri [Diagnostic]
126   , curTimeoutId :: Int
127   , overridingTimeout :: Bool
128   -- ^ The last received message from the server.
129   -- Used for providing exception information
130   , lastReceivedMessage :: Maybe FromServerMessage
131   }
132
133 class Monad m => HasState s m where
134   get :: m s
135
136   put :: s -> m ()
137
138   modify :: (s -> s) -> m ()
139   modify f = get >>= put . f
140
141   modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
142   modifyM f = get >>= f >>= put
143
144 instance Monad m => HasState s (ParserStateReader a s r m) where
145   get = lift State.get
146   put = lift . State.put
147
148 instance Monad m => HasState SessionState (ConduitM a b (StateT SessionState m))
149  where
150   get = lift State.get
151   put = lift . State.put
152
153 type ParserStateReader a s r m = ConduitParser a (StateT s (ReaderT r m))
154
155 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
156 runSession context state session = runReaderT (runStateT conduit state) context
157   where
158     conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
159
160     handler (Unexpected "ConduitParser.empty") = do
161       lastMsg <- fromJust . lastReceivedMessage <$> get
162       name <- getParserName
163       liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
164
165     handler e = throw e
166
167     chanSource = do
168       msg <- liftIO $ readChan (messageChan context)
169       yield msg
170       chanSource
171
172     watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
173     watchdog = Conduit.awaitForever $ \msg -> do
174       curId <- curTimeoutId <$> get
175       case msg of
176         ServerMessage sMsg -> yield sMsg
177         TimeoutMessage tId -> when (curId == tId) $ throw Timeout
178
179 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
180 -- It also does not automatically send initialize and exit messages.
181 runSessionWithHandles :: Handle -- ^ Server in
182                       -> Handle -- ^ Server out
183                       -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
184                       -> SessionConfig
185                       -> ClientCapabilities
186                       -> FilePath -- ^ Root directory
187                       -> Session a
188                       -> IO a
189 runSessionWithHandles serverIn serverOut serverHandler config caps rootDir session = do
190   absRootDir <- canonicalizePath rootDir
191
192   hSetBuffering serverIn  NoBuffering
193   hSetBuffering serverOut NoBuffering
194
195   reqMap <- newMVar newRequestMap
196   messageChan <- newChan
197   initRsp <- newEmptyMVar
198
199   let context = SessionContext serverIn absRootDir messageChan reqMap initRsp config caps
200       initState = SessionState (IdInt 0) mempty mempty 0 False Nothing
201
202   threadId <- forkIO $ void $ serverHandler serverOut context
203   (result, _) <- runSession context initState session
204
205   killThread threadId
206
207   return result
208
209 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
210 updateStateC = awaitForever $ \msg -> do
211   updateState msg
212   yield msg
213
214 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m) => FromServerMessage -> m ()
215 updateState (NotPublishDiagnostics n) = do
216   let List diags = n ^. params . diagnostics
217       doc = n ^. params . uri
218   modify (\s ->
219     let newDiags = Map.insert doc diags (curDiagnostics s)
220       in s { curDiagnostics = newDiags })
221
222 updateState (ReqApplyWorkspaceEdit r) = do
223
224   allChangeParams <- case r ^. params . edit . documentChanges of
225     Just (List cs) -> do
226       mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
227       return $ map getParams cs
228     Nothing -> case r ^. params . edit . changes of
229       Just cs -> do
230         mapM_ checkIfNeedsOpened (HashMap.keys cs)
231         return $ concatMap (uncurry getChangeParams) (HashMap.toList cs)
232       Nothing -> error "No changes!"
233
234   modifyM $ \s -> do
235     newVFS <- liftIO $ changeFromServerVFS (vfs s) r
236     return $ s { vfs = newVFS }
237
238   let groupedParams = groupBy (\a b -> (a ^. textDocument == b ^. textDocument)) allChangeParams
239       mergedParams = map mergeParams groupedParams
240
241   -- TODO: Don't do this when replaying a session
242   forM_ mergedParams (sendMessage . NotificationMessage "2.0" TextDocumentDidChange)
243
244   -- Update VFS to new document versions
245   let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
246       latestVersions = map ((^. textDocument) . last) sortedVersions
247       bumpedVersions = map (version . _Just +~ 1) latestVersions
248
249   forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
250     modify $ \s ->
251       let oldVFS = vfs s
252           update (VirtualFile oldV t) = VirtualFile (fromMaybe oldV v) t
253           newVFS = Map.adjust update uri oldVFS
254       in s { vfs = newVFS }
255
256   where checkIfNeedsOpened uri = do
257           oldVFS <- vfs <$> get
258           ctx <- ask
259
260           -- if its not open, open it
261           unless (uri `Map.member` oldVFS) $ do
262             let fp = fromJust $ uriToFilePath uri
263             contents <- liftIO $ T.readFile fp
264             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
265                 msg = NotificationMessage "2.0" TextDocumentDidOpen (DidOpenTextDocumentParams item)
266             liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
267
268             modifyM $ \s -> do
269               newVFS <- liftIO $ openVFS (vfs s) msg
270               return $ s { vfs = newVFS }
271
272         getParams (TextDocumentEdit docId (List edits)) =
273           let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
274             in DidChangeTextDocumentParams docId (List changeEvents)
275
276         textDocumentVersions uri = map (VersionedTextDocumentIdentifier uri . Just) [0..]
277
278         textDocumentEdits uri edits = map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip (textDocumentVersions uri) edits
279
280         getChangeParams uri (List edits) = map getParams (textDocumentEdits uri (reverse edits))
281
282         mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
283         mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
284                               in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
285 updateState _ = return ()
286
287 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
288 sendMessage msg = do
289   h <- serverIn <$> ask
290   logMsg LogClient msg
291   liftIO $ B.hPut h (addHeader $ encode msg)
292
293 -- | Execute a block f that will throw a 'TimeoutException'
294 -- after duration seconds. This will override the global timeout
295 -- for waiting for messages to arrive defined in 'SessionConfig'.
296 withTimeout :: Int -> Session a -> Session a
297 withTimeout duration f = do
298   chan <- asks messageChan
299   timeoutId <- curTimeoutId <$> get
300   modify $ \s -> s { overridingTimeout = True }
301   liftIO $ forkIO $ do
302     threadDelay (duration * 1000000)
303     writeChan chan (TimeoutMessage timeoutId)
304   res <- f
305   modify $ \s -> s { curTimeoutId = timeoutId + 1,
306                      overridingTimeout = False
307                    }
308   return res
309
310 data LogMsgType = LogServer | LogClient
311   deriving Eq
312
313 -- | Logs the message if the config specified it
314 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
315        => LogMsgType -> a -> m ()
316 logMsg t msg = do
317   shouldLog <- asks $ logMessages . config
318   shouldColor <- asks $ logColor . config
319   liftIO $ when shouldLog $ do
320     when shouldColor $ setSGR [SetColor Foreground Dull color]
321     putStrLn $ arrow ++ showPretty msg
322     when shouldColor $ setSGR [Reset]
323
324   where arrow
325           | t == LogServer  = "<-- "
326           | otherwise       = "--> "
327         color
328           | t == LogServer  = Magenta
329           | otherwise       = Cyan
330
331         showPretty = B.unpack . encodePretty