Update to haskell-lsp-0.7
[opengl.git] / src / Language / Haskell / LSP / Test / Session.hs
1 {-# LANGUAGE OverloadedStrings #-}
2 {-# LANGUAGE FlexibleInstances #-}
3 {-# LANGUAGE MultiParamTypeClasses #-}
4 {-# LANGUAGE FlexibleContexts #-}
5
6 module Language.Haskell.LSP.Test.Session
7   ( Session
8   , SessionConfig(..)
9   , defaultConfig
10   , SessionMessage(..)
11   , SessionContext(..)
12   , SessionState(..)
13   , runSessionWithHandles
14   , get
15   , put
16   , modify
17   , modifyM
18   , ask
19   , asks
20   , sendMessage
21   , updateState
22   , withTimeout
23   , logMsg
24   , LogMsgType(..)
25   )
26
27 where
28
29 import Control.Concurrent hiding (yield)
30 import Control.Exception
31 import Control.Lens hiding (List)
32 import Control.Monad
33 import Control.Monad.IO.Class
34 import Control.Monad.Except
35 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
36 import qualified Control.Monad.Trans.Reader as Reader (ask)
37 import Control.Monad.Trans.State (StateT, runStateT)
38 import qualified Control.Monad.Trans.State as State (get, put)
39 import qualified Data.ByteString.Lazy.Char8 as B
40 import Data.Aeson
41 import Data.Aeson.Encode.Pretty
42 import Data.Conduit as Conduit
43 import Data.Conduit.Parser as Parser
44 import Data.Default
45 import Data.Foldable
46 import Data.List
47 import qualified Data.Map as Map
48 import qualified Data.Text as T
49 import qualified Data.Text.IO as T
50 import qualified Data.HashMap.Strict as HashMap
51 import Data.Maybe
52 import Data.Function
53 import Language.Haskell.LSP.Messages
54 import Language.Haskell.LSP.Types.Capabilities
55 import Language.Haskell.LSP.Types hiding (error)
56 import Language.Haskell.LSP.VFS
57 import Language.Haskell.LSP.Test.Decoding
58 import Language.Haskell.LSP.Test.Exceptions
59 import System.Console.ANSI
60 import System.Directory
61 import System.IO
62
63 -- | A session representing one instance of launching and connecting to a server.
64 -- 
65 -- You can send and receive messages to the server within 'Session' via 'getMessage',
66 -- 'sendRequest' and 'sendNotification'.
67 --
68
69 type Session = ParserStateReader FromServerMessage SessionState SessionContext IO
70
71 -- | Stuff you can configure for a 'Session'.
72 data SessionConfig = SessionConfig
73   { messageTimeout :: Int  -- ^ Maximum time to wait for a message in seconds, defaults to 60.
74   , logStdErr      :: Bool -- ^ Redirect the server's stderr to this stdout, defaults to False.
75   , logMessages    :: Bool -- ^ Trace the messages sent and received to stdout, defaults to True.
76   , logColor       :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
77   }
78
79 -- | The configuration used in 'Language.Haskell.LSP.Test.runSession'.
80 defaultConfig :: SessionConfig
81 defaultConfig = SessionConfig 60 False False True
82
83 instance Default SessionConfig where
84   def = defaultConfig
85
86 data SessionMessage = ServerMessage FromServerMessage
87                     | TimeoutMessage Int
88   deriving Show
89
90 data SessionContext = SessionContext
91   {
92     serverIn :: Handle
93   , rootDir :: FilePath
94   , messageChan :: Chan SessionMessage
95   , requestMap :: MVar RequestMap
96   , initRsp :: MVar InitializeResponse
97   , config :: SessionConfig
98   , sessionCapabilities :: ClientCapabilities
99   }
100
101 class Monad m => HasReader r m where
102   ask :: m r
103   asks :: (r -> b) -> m b
104   asks f = f <$> ask
105
106 instance Monad m => HasReader r (ParserStateReader a s r m) where
107   ask = lift $ lift Reader.ask
108
109 instance Monad m => HasReader SessionContext (ConduitM a b (StateT s (ReaderT SessionContext m))) where
110   ask = lift $ lift Reader.ask
111
112 data SessionState = SessionState
113   {
114     curReqId :: LspId
115   , vfs :: VFS
116   , curDiagnostics :: Map.Map Uri [Diagnostic]
117   , curTimeoutId :: Int
118   , overridingTimeout :: Bool
119   -- ^ The last received message from the server.
120   -- Used for providing exception information
121   , lastReceivedMessage :: Maybe FromServerMessage
122   }
123
124 class Monad m => HasState s m where
125   get :: m s
126
127   put :: s -> m ()
128
129   modify :: (s -> s) -> m ()
130   modify f = get >>= put . f
131
132   modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
133   modifyM f = get >>= f >>= put
134
135 instance Monad m => HasState s (ParserStateReader a s r m) where
136   get = lift State.get
137   put = lift . State.put
138
139 instance Monad m => HasState SessionState (ConduitM a b (StateT SessionState m))
140  where
141   get = lift State.get
142   put = lift . State.put
143
144 type ParserStateReader a s r m = ConduitParser a (StateT s (ReaderT r m))
145
146 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
147 runSession context state session = runReaderT (runStateT conduit state) context
148   where
149     conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
150         
151     handler (Unexpected "ConduitParser.empty") = do
152       lastMsg <- fromJust . lastReceivedMessage <$> get
153       name <- getParserName
154       liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
155
156     handler e = throw e
157
158     chanSource = do
159       msg <- liftIO $ readChan (messageChan context)
160       yield msg
161       chanSource
162
163     watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
164     watchdog = Conduit.awaitForever $ \msg -> do
165       curId <- curTimeoutId <$> get
166       case msg of
167         ServerMessage sMsg -> yield sMsg
168         TimeoutMessage tId -> when (curId == tId) $ throw Timeout
169
170 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
171 -- It also does not automatically send initialize and exit messages.
172 runSessionWithHandles :: Handle -- ^ Server in
173                       -> Handle -- ^ Server out
174                       -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
175                       -> SessionConfig
176                       -> ClientCapabilities
177                       -> FilePath -- ^ Root directory
178                       -> Session a
179                       -> IO a
180 runSessionWithHandles serverIn serverOut serverHandler config caps rootDir session = do
181   absRootDir <- canonicalizePath rootDir
182
183   hSetBuffering serverIn  NoBuffering
184   hSetBuffering serverOut NoBuffering
185
186   reqMap <- newMVar newRequestMap
187   messageChan <- newChan
188   initRsp <- newEmptyMVar
189
190   let context = SessionContext serverIn absRootDir messageChan reqMap initRsp config caps
191       initState = SessionState (IdInt 0) mempty mempty 0 False Nothing
192
193   threadId <- forkIO $ void $ serverHandler serverOut context
194   (result, _) <- runSession context initState session
195
196   killThread threadId
197
198   return result
199
200 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
201 updateStateC = awaitForever $ \msg -> do
202   updateState msg
203   yield msg
204
205 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m) => FromServerMessage -> m ()
206 updateState (NotPublishDiagnostics n) = do
207   let List diags = n ^. params . diagnostics
208       doc = n ^. params . uri
209   modify (\s ->
210     let newDiags = Map.insert doc diags (curDiagnostics s)
211       in s { curDiagnostics = newDiags })
212
213 updateState (ReqApplyWorkspaceEdit r) = do
214
215   allChangeParams <- case r ^. params . edit . documentChanges of
216     Just (List cs) -> do
217       mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
218       return $ map getParams cs
219     Nothing -> case r ^. params . edit . changes of
220       Just cs -> do
221         mapM_ checkIfNeedsOpened (HashMap.keys cs)
222         return $ concatMap (uncurry getChangeParams) (HashMap.toList cs)
223       Nothing -> error "No changes!"
224
225   modifyM $ \s -> do
226     newVFS <- liftIO $ changeFromServerVFS (vfs s) r
227     return $ s { vfs = newVFS }
228
229   let groupedParams = groupBy (\a b -> (a ^. textDocument == b ^. textDocument)) allChangeParams
230       mergedParams = map mergeParams groupedParams
231
232   -- TODO: Don't do this when replaying a session
233   forM_ mergedParams (sendMessage . NotificationMessage "2.0" TextDocumentDidChange)
234
235   -- Update VFS to new document versions
236   let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
237       latestVersions = map ((^. textDocument) . last) sortedVersions
238       bumpedVersions = map (version . _Just +~ 1) latestVersions
239
240   forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
241     modify $ \s ->
242       let oldVFS = vfs s
243           update (VirtualFile oldV t) = VirtualFile (fromMaybe oldV v) t
244           newVFS = Map.adjust update uri oldVFS
245       in s { vfs = newVFS }
246
247   where checkIfNeedsOpened uri = do
248           oldVFS <- vfs <$> get
249           ctx <- ask
250
251           -- if its not open, open it
252           unless (uri `Map.member` oldVFS) $ do
253             let fp = fromJust $ uriToFilePath uri
254             contents <- liftIO $ T.readFile fp
255             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
256                 msg = NotificationMessage "2.0" TextDocumentDidOpen (DidOpenTextDocumentParams item)
257             liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
258
259             modifyM $ \s -> do 
260               newVFS <- liftIO $ openVFS (vfs s) msg
261               return $ s { vfs = newVFS }
262
263         getParams (TextDocumentEdit docId (List edits)) =
264           let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
265             in DidChangeTextDocumentParams docId (List changeEvents)
266
267         textDocumentVersions uri = map (VersionedTextDocumentIdentifier uri . Just) [0..]
268
269         textDocumentEdits uri edits = map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip (textDocumentVersions uri) edits
270
271         getChangeParams uri (List edits) = map getParams (textDocumentEdits uri (reverse edits))
272
273         mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
274         mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
275                               in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
276 updateState _ = return ()
277
278 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
279 sendMessage msg = do
280   h <- serverIn <$> ask
281   logMsg LogClient msg
282   liftIO $ B.hPut h (addHeader $ encode msg)
283
284 -- | Execute a block f that will throw a 'TimeoutException'
285 -- after duration seconds. This will override the global timeout
286 -- for waiting for messages to arrive defined in 'SessionConfig'.
287 withTimeout :: Int -> Session a -> Session a
288 withTimeout duration f = do
289   chan <- asks messageChan
290   timeoutId <- curTimeoutId <$> get 
291   modify $ \s -> s { overridingTimeout = True }
292   liftIO $ forkIO $ do
293     threadDelay (duration * 1000000)
294     writeChan chan (TimeoutMessage timeoutId)
295   res <- f
296   modify $ \s -> s { curTimeoutId = timeoutId + 1,
297                      overridingTimeout = False 
298                    }
299   return res
300
301 data LogMsgType = LogServer | LogClient
302   deriving Eq
303
304 -- | Logs the message if the config specified it
305 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
306        => LogMsgType -> a -> m ()
307 logMsg t msg = do
308   shouldLog <- asks $ logMessages . config
309   shouldColor <- asks $ logColor . config
310   liftIO $ when shouldLog $ do
311     when shouldColor $ setSGR [SetColor Foreground Dull color]
312     putStrLn $ arrow ++ showPretty msg
313     when shouldColor $ setSGR [Reset]
314
315   where arrow
316           | t == LogServer  = "<-- "
317           | otherwise       = "--> "
318         color
319           | t == LogServer  = Magenta
320           | otherwise       = Cyan
321   
322         showPretty = B.unpack . encodePretty