86030df68d2345c570e1f07884222816171026e9
[opengl.git] / src / Language / Haskell / LSP / Test / Session.hs
1 {-# LANGUAGE OverloadedStrings #-}
2 {-# LANGUAGE FlexibleInstances #-}
3 {-# LANGUAGE MultiParamTypeClasses #-}
4 {-# LANGUAGE FlexibleContexts #-}
5
6 module Language.Haskell.LSP.Test.Session
7   ( Session
8   , SessionConfig(..)
9   , defaultConfig
10   , SessionMessage(..)
11   , SessionContext(..)
12   , SessionState(..)
13   , runSessionWithHandles
14   , get
15   , put
16   , modify
17   , modifyM
18   , ask
19   , asks
20   , sendMessage
21   , updateState
22   , withTimeout
23   , logMsg
24   , LogMsgType(..)
25   )
26
27 where
28
29 import Control.Concurrent hiding (yield)
30 import Control.Exception
31 import Control.Lens hiding (List)
32 import Control.Monad
33 import Control.Monad.IO.Class
34 import Control.Monad.Except
35 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
36 import qualified Control.Monad.Trans.Reader as Reader (ask)
37 import Control.Monad.Trans.State (StateT, runStateT)
38 import qualified Control.Monad.Trans.State as State (get, put)
39 import qualified Data.ByteString.Lazy.Char8 as B
40 import Data.Aeson
41 import Data.Aeson.Encode.Pretty
42 import Data.Conduit as Conduit
43 import Data.Conduit.Parser as Parser
44 import Data.Default
45 import Data.Foldable
46 import Data.List
47 import qualified Data.Map as Map
48 import qualified Data.Text as T
49 import qualified Data.Text.IO as T
50 import qualified Data.HashMap.Strict as HashMap
51 import Data.Maybe
52 import Data.Function
53 import Language.Haskell.LSP.Messages
54 import Language.Haskell.LSP.Types.Capabilities
55 import Language.Haskell.LSP.Types hiding (error)
56 import Language.Haskell.LSP.VFS
57 import Language.Haskell.LSP.Test.Decoding
58 import Language.Haskell.LSP.Test.Exceptions
59 import System.Console.ANSI
60 import System.Directory
61 import System.IO
62
63 -- | A session representing one instance of launching and connecting to a server.
64 -- 
65 -- You can send and receive messages to the server within 'Session' via 'getMessage',
66 -- 'sendRequest' and 'sendNotification'.
67 --
68 -- @
69 -- runSession \"path\/to\/root\/dir\" $ do
70 --   docItem <- getDocItem "Desktop/simple.hs" "haskell"
71 --   sendNotification TextDocumentDidOpen (DidOpenTextDocumentParams docItem)
72 --   diagnostics <- getMessage :: Session PublishDiagnosticsNotification
73 -- @
74 type Session = ParserStateReader FromServerMessage SessionState SessionContext IO
75
76 -- | Stuff you can configure for a 'Session'.
77 data SessionConfig = SessionConfig
78   { messageTimeout :: Int  -- ^ Maximum time to wait for a message in seconds, defaults to 60.
79   , logStdErr      :: Bool -- ^ Redirect the server's stderr to this stdout, defaults to False.
80   , logMessages    :: Bool -- ^ Trace the messages sent and received to stdout, defaults to True.
81   , logColor       :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
82   }
83
84 defaultConfig :: SessionConfig
85 defaultConfig = SessionConfig 60 False True True
86
87 instance Default SessionConfig where
88   def = defaultConfig
89
90 data SessionMessage = ServerMessage FromServerMessage
91                     | TimeoutMessage Int
92   deriving Show
93
94 data SessionContext = SessionContext
95   {
96     serverIn :: Handle
97   , rootDir :: FilePath
98   , messageChan :: Chan SessionMessage
99   , requestMap :: MVar RequestMap
100   , initRsp :: MVar InitializeResponse
101   , config :: SessionConfig
102   , sessionCapabilities :: ClientCapabilities
103   }
104
105 class Monad m => HasReader r m where
106   ask :: m r
107   asks :: (r -> b) -> m b
108   asks f = f <$> ask
109
110 instance Monad m => HasReader r (ParserStateReader a s r m) where
111   ask = lift $ lift Reader.ask
112
113 instance Monad m => HasReader SessionContext (ConduitM a b (StateT s (ReaderT SessionContext m))) where
114   ask = lift $ lift Reader.ask
115
116 data SessionState = SessionState
117   {
118     curReqId :: LspId
119   , vfs :: VFS
120   , curDiagnostics :: Map.Map Uri [Diagnostic]
121   , curTimeoutId :: Int
122   , overridingTimeout :: Bool
123   -- ^ The last received message from the server.
124   -- Used for providing exception information
125   , lastReceivedMessage :: Maybe FromServerMessage
126   }
127
128 class Monad m => HasState s m where
129   get :: m s
130
131   put :: s -> m ()
132
133   modify :: (s -> s) -> m ()
134   modify f = get >>= put . f
135
136   modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
137   modifyM f = get >>= f >>= put
138
139 instance Monad m => HasState s (ParserStateReader a s r m) where
140   get = lift State.get
141   put = lift . State.put
142
143 instance Monad m => HasState SessionState (ConduitM a b (StateT SessionState m))
144  where
145   get = lift State.get
146   put = lift . State.put
147
148 type ParserStateReader a s r m = ConduitParser a (StateT s (ReaderT r m))
149
150 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
151 runSession context state session = runReaderT (runStateT conduit state) context
152   where
153     conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
154         
155     handler (Unexpected "ConduitParser.empty") = do
156       lastMsg <- fromJust . lastReceivedMessage <$> get
157       name <- getParserName
158       liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
159
160     handler e = throw e
161
162     chanSource = do
163       msg <- liftIO $ readChan (messageChan context)
164       yield msg
165       chanSource
166
167     watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
168     watchdog = Conduit.awaitForever $ \msg -> do
169       curId <- curTimeoutId <$> get
170       case msg of
171         ServerMessage sMsg -> yield sMsg
172         TimeoutMessage tId -> when (curId == tId) $ throw Timeout
173
174 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
175 -- It also does not automatically send initialize and exit messages.
176 runSessionWithHandles :: Handle -- ^ Server in
177                       -> Handle -- ^ Server out
178                       -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
179                       -> SessionConfig
180                       -> ClientCapabilities
181                       -> FilePath -- ^ Root directory
182                       -> Session a
183                       -> IO a
184 runSessionWithHandles serverIn serverOut serverHandler config caps rootDir session = do
185   absRootDir <- canonicalizePath rootDir
186
187   hSetBuffering serverIn  NoBuffering
188   hSetBuffering serverOut NoBuffering
189
190   reqMap <- newMVar newRequestMap
191   messageChan <- newChan
192   initRsp <- newEmptyMVar
193
194   let context = SessionContext serverIn absRootDir messageChan reqMap initRsp config caps
195       initState = SessionState (IdInt 0) mempty mempty 0 False Nothing
196
197   threadId <- forkIO $ void $ serverHandler serverOut context
198   (result, _) <- runSession context initState session
199
200   killThread threadId
201
202   return result
203
204 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
205 updateStateC = awaitForever $ \msg -> do
206   updateState msg
207   yield msg
208
209 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m) => FromServerMessage -> m ()
210 updateState (NotPublishDiagnostics n) = do
211   let List diags = n ^. params . diagnostics
212       doc = n ^. params . uri
213   modify (\s ->
214     let newDiags = Map.insert doc diags (curDiagnostics s)
215       in s { curDiagnostics = newDiags })
216
217 updateState (ReqApplyWorkspaceEdit r) = do
218
219   allChangeParams <- case r ^. params . edit . documentChanges of
220     Just (List cs) -> do
221       mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
222       return $ map getParams cs
223     Nothing -> case r ^. params . edit . changes of
224       Just cs -> do
225         mapM_ checkIfNeedsOpened (HashMap.keys cs)
226         return $ concatMap (uncurry getChangeParams) (HashMap.toList cs)
227       Nothing -> error "No changes!"
228
229   modifyM $ \s -> do
230     newVFS <- liftIO $ changeFromServerVFS (vfs s) r
231     return $ s { vfs = newVFS }
232
233   let groupedParams = groupBy (\a b -> (a ^. textDocument == b ^. textDocument)) allChangeParams
234       mergedParams = map mergeParams groupedParams
235
236   -- TODO: Don't do this when replaying a session
237   forM_ mergedParams (sendMessage . NotificationMessage "2.0" TextDocumentDidChange)
238
239   -- Update VFS to new document versions
240   let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
241       latestVersions = map ((^. textDocument) . last) sortedVersions
242       bumpedVersions = map (version . _Just +~ 1) latestVersions
243
244   forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
245     modify $ \s ->
246       let oldVFS = vfs s
247           update (VirtualFile oldV t) = VirtualFile (fromMaybe oldV v) t
248           newVFS = Map.adjust update uri oldVFS
249       in s { vfs = newVFS }
250
251   where checkIfNeedsOpened uri = do
252           oldVFS <- vfs <$> get
253           ctx <- ask
254
255           -- if its not open, open it
256           unless (uri `Map.member` oldVFS) $ do
257             let fp = fromJust $ uriToFilePath uri
258             contents <- liftIO $ T.readFile fp
259             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
260                 msg = NotificationMessage "2.0" TextDocumentDidOpen (DidOpenTextDocumentParams item)
261             liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
262
263             modifyM $ \s -> do 
264               newVFS <- liftIO $ openVFS (vfs s) msg
265               return $ s { vfs = newVFS }
266
267         getParams (TextDocumentEdit docId (List edits)) =
268           let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
269             in DidChangeTextDocumentParams docId (List changeEvents)
270
271         textDocumentVersions uri = map (VersionedTextDocumentIdentifier uri . Just) [0..]
272
273         textDocumentEdits uri edits = map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip (textDocumentVersions uri) edits
274
275         getChangeParams uri (List edits) = map getParams (textDocumentEdits uri (reverse edits))
276
277         mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
278         mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
279                               in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
280 updateState _ = return ()
281
282 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
283 sendMessage msg = do
284   h <- serverIn <$> ask
285   logMsg LogClient msg
286   liftIO $ B.hPut h (addHeader $ encode msg)
287
288 -- | Execute a block f that will throw a 'TimeoutException'
289 -- after duration seconds. This will override the global timeout
290 -- for waiting for messages to arrive defined in 'SessionConfig'.
291 withTimeout :: Int -> Session a -> Session a
292 withTimeout duration f = do
293   chan <- asks messageChan
294   timeoutId <- curTimeoutId <$> get 
295   modify $ \s -> s { overridingTimeout = True }
296   liftIO $ forkIO $ do
297     threadDelay (duration * 1000000)
298     writeChan chan (TimeoutMessage timeoutId)
299   res <- f
300   modify $ \s -> s { curTimeoutId = timeoutId + 1,
301                      overridingTimeout = False 
302                    }
303   return res
304
305 data LogMsgType = LogServer | LogClient
306   deriving Eq
307
308 -- | Logs the message if the config specified it
309 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
310        => LogMsgType -> a -> m ()
311 logMsg t msg = do
312   shouldLog <- asks $ logMessages . config
313   shouldColor <- asks $ logColor . config
314   liftIO $ when shouldLog $ do
315     when shouldColor $ setSGR [SetColor Foreground Dull color]
316     putStrLn $ arrow ++ showPretty msg
317     when shouldColor $ setSGR [Reset]
318
319   where arrow
320           | t == LogServer  = "<-- "
321           | otherwise       = "--> "
322         color
323           | t == LogServer  = Magenta
324           | otherwise       = Cyan
325   
326         showPretty = B.unpack . encodePretty