Rename sendRequest to request, sendRequest' to sendRequest
[opengl.git] / src / Language / Haskell / LSP / Test / Session.hs
1 {-# LANGUAGE OverloadedStrings #-}
2 {-# LANGUAGE FlexibleInstances #-}
3 {-# LANGUAGE MultiParamTypeClasses #-}
4 {-# LANGUAGE FlexibleContexts #-}
5
6 module Language.Haskell.LSP.Test.Session
7   ( Session
8   , SessionConfig(..)
9   , defaultConfig
10   , SessionMessage(..)
11   , SessionContext(..)
12   , SessionState(..)
13   , runSessionWithHandles
14   , get
15   , put
16   , modify
17   , modifyM
18   , ask
19   , asks
20   , sendMessage
21   , updateState
22   , withTimeout
23   , logMsg
24   , LogMsgType(..)
25   )
26
27 where
28
29 import Control.Concurrent hiding (yield)
30 import Control.Exception
31 import Control.Lens hiding (List)
32 import Control.Monad
33 import Control.Monad.IO.Class
34 import Control.Monad.Except
35 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
36 import qualified Control.Monad.Trans.Reader as Reader (ask)
37 import Control.Monad.Trans.State (StateT, runStateT)
38 import qualified Control.Monad.Trans.State as State (get, put)
39 import qualified Data.ByteString.Lazy.Char8 as B
40 import Data.Aeson
41 import Data.Aeson.Encode.Pretty
42 import Data.Conduit as Conduit
43 import Data.Conduit.Parser as Parser
44 import Data.Default
45 import Data.Foldable
46 import Data.List
47 import qualified Data.Map as Map
48 import qualified Data.Text as T
49 import qualified Data.Text.IO as T
50 import qualified Data.HashMap.Strict as HashMap
51 import Data.Maybe
52 import Data.Function
53 import Language.Haskell.LSP.Messages
54 import Language.Haskell.LSP.Types.Capabilities
55 import Language.Haskell.LSP.Types hiding (error)
56 import Language.Haskell.LSP.VFS
57 import Language.Haskell.LSP.Test.Decoding
58 import Language.Haskell.LSP.Test.Exceptions
59 import System.Console.ANSI
60 import System.Directory
61 import System.IO
62
63 -- | A session representing one instance of launching and connecting to a server.
64 -- 
65 -- You can send and receive messages to the server within 'Session' via 'getMessage',
66 -- 'sendRequest' and 'sendNotification'.
67 --
68 -- @
69 -- runSession \"path\/to\/root\/dir\" $ do
70 --   docItem <- getDocItem "Desktop/simple.hs" "haskell"
71 --   sendNotification TextDocumentDidOpen (DidOpenTextDocumentParams docItem)
72 --   diagnostics <- getMessage :: Session PublishDiagnosticsNotification
73 -- @
74 type Session = ParserStateReader FromServerMessage SessionState SessionContext IO
75
76 -- | Stuff you can configure for a 'Session'.
77 data SessionConfig = SessionConfig
78   { messageTimeout :: Int  -- ^ Maximum time to wait for a message in seconds, defaults to 60.
79   , logStdErr      :: Bool -- ^ Redirect the server's stderr to this stdout, defaults to False.
80   , logMessages    :: Bool -- ^ Trace the messages sent and received to stdout, defaults to True.
81   , logColor       :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
82   }
83
84 -- | The configuration used in 'Language.Haskell.LSP.Test.runSession'.
85 defaultConfig :: SessionConfig
86 defaultConfig = SessionConfig 60 False True True
87
88 instance Default SessionConfig where
89   def = defaultConfig
90
91 data SessionMessage = ServerMessage FromServerMessage
92                     | TimeoutMessage Int
93   deriving Show
94
95 data SessionContext = SessionContext
96   {
97     serverIn :: Handle
98   , rootDir :: FilePath
99   , messageChan :: Chan SessionMessage
100   , requestMap :: MVar RequestMap
101   , initRsp :: MVar InitializeResponse
102   , config :: SessionConfig
103   , sessionCapabilities :: ClientCapabilities
104   }
105
106 class Monad m => HasReader r m where
107   ask :: m r
108   asks :: (r -> b) -> m b
109   asks f = f <$> ask
110
111 instance Monad m => HasReader r (ParserStateReader a s r m) where
112   ask = lift $ lift Reader.ask
113
114 instance Monad m => HasReader SessionContext (ConduitM a b (StateT s (ReaderT SessionContext m))) where
115   ask = lift $ lift Reader.ask
116
117 data SessionState = SessionState
118   {
119     curReqId :: LspId
120   , vfs :: VFS
121   , curDiagnostics :: Map.Map Uri [Diagnostic]
122   , curTimeoutId :: Int
123   , overridingTimeout :: Bool
124   -- ^ The last received message from the server.
125   -- Used for providing exception information
126   , lastReceivedMessage :: Maybe FromServerMessage
127   }
128
129 class Monad m => HasState s m where
130   get :: m s
131
132   put :: s -> m ()
133
134   modify :: (s -> s) -> m ()
135   modify f = get >>= put . f
136
137   modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
138   modifyM f = get >>= f >>= put
139
140 instance Monad m => HasState s (ParserStateReader a s r m) where
141   get = lift State.get
142   put = lift . State.put
143
144 instance Monad m => HasState SessionState (ConduitM a b (StateT SessionState m))
145  where
146   get = lift State.get
147   put = lift . State.put
148
149 type ParserStateReader a s r m = ConduitParser a (StateT s (ReaderT r m))
150
151 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
152 runSession context state session = runReaderT (runStateT conduit state) context
153   where
154     conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
155         
156     handler (Unexpected "ConduitParser.empty") = do
157       lastMsg <- fromJust . lastReceivedMessage <$> get
158       name <- getParserName
159       liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
160
161     handler e = throw e
162
163     chanSource = do
164       msg <- liftIO $ readChan (messageChan context)
165       yield msg
166       chanSource
167
168     watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
169     watchdog = Conduit.awaitForever $ \msg -> do
170       curId <- curTimeoutId <$> get
171       case msg of
172         ServerMessage sMsg -> yield sMsg
173         TimeoutMessage tId -> when (curId == tId) $ throw Timeout
174
175 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
176 -- It also does not automatically send initialize and exit messages.
177 runSessionWithHandles :: Handle -- ^ Server in
178                       -> Handle -- ^ Server out
179                       -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
180                       -> SessionConfig
181                       -> ClientCapabilities
182                       -> FilePath -- ^ Root directory
183                       -> Session a
184                       -> IO a
185 runSessionWithHandles serverIn serverOut serverHandler config caps rootDir session = do
186   absRootDir <- canonicalizePath rootDir
187
188   hSetBuffering serverIn  NoBuffering
189   hSetBuffering serverOut NoBuffering
190
191   reqMap <- newMVar newRequestMap
192   messageChan <- newChan
193   initRsp <- newEmptyMVar
194
195   let context = SessionContext serverIn absRootDir messageChan reqMap initRsp config caps
196       initState = SessionState (IdInt 0) mempty mempty 0 False Nothing
197
198   threadId <- forkIO $ void $ serverHandler serverOut context
199   (result, _) <- runSession context initState session
200
201   killThread threadId
202
203   return result
204
205 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
206 updateStateC = awaitForever $ \msg -> do
207   updateState msg
208   yield msg
209
210 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m) => FromServerMessage -> m ()
211 updateState (NotPublishDiagnostics n) = do
212   let List diags = n ^. params . diagnostics
213       doc = n ^. params . uri
214   modify (\s ->
215     let newDiags = Map.insert doc diags (curDiagnostics s)
216       in s { curDiagnostics = newDiags })
217
218 updateState (ReqApplyWorkspaceEdit r) = do
219
220   allChangeParams <- case r ^. params . edit . documentChanges of
221     Just (List cs) -> do
222       mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
223       return $ map getParams cs
224     Nothing -> case r ^. params . edit . changes of
225       Just cs -> do
226         mapM_ checkIfNeedsOpened (HashMap.keys cs)
227         return $ concatMap (uncurry getChangeParams) (HashMap.toList cs)
228       Nothing -> error "No changes!"
229
230   modifyM $ \s -> do
231     newVFS <- liftIO $ changeFromServerVFS (vfs s) r
232     return $ s { vfs = newVFS }
233
234   let groupedParams = groupBy (\a b -> (a ^. textDocument == b ^. textDocument)) allChangeParams
235       mergedParams = map mergeParams groupedParams
236
237   -- TODO: Don't do this when replaying a session
238   forM_ mergedParams (sendMessage . NotificationMessage "2.0" TextDocumentDidChange)
239
240   -- Update VFS to new document versions
241   let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
242       latestVersions = map ((^. textDocument) . last) sortedVersions
243       bumpedVersions = map (version . _Just +~ 1) latestVersions
244
245   forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
246     modify $ \s ->
247       let oldVFS = vfs s
248           update (VirtualFile oldV t) = VirtualFile (fromMaybe oldV v) t
249           newVFS = Map.adjust update uri oldVFS
250       in s { vfs = newVFS }
251
252   where checkIfNeedsOpened uri = do
253           oldVFS <- vfs <$> get
254           ctx <- ask
255
256           -- if its not open, open it
257           unless (uri `Map.member` oldVFS) $ do
258             let fp = fromJust $ uriToFilePath uri
259             contents <- liftIO $ T.readFile fp
260             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
261                 msg = NotificationMessage "2.0" TextDocumentDidOpen (DidOpenTextDocumentParams item)
262             liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
263
264             modifyM $ \s -> do 
265               newVFS <- liftIO $ openVFS (vfs s) msg
266               return $ s { vfs = newVFS }
267
268         getParams (TextDocumentEdit docId (List edits)) =
269           let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
270             in DidChangeTextDocumentParams docId (List changeEvents)
271
272         textDocumentVersions uri = map (VersionedTextDocumentIdentifier uri . Just) [0..]
273
274         textDocumentEdits uri edits = map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip (textDocumentVersions uri) edits
275
276         getChangeParams uri (List edits) = map getParams (textDocumentEdits uri (reverse edits))
277
278         mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
279         mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
280                               in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
281 updateState _ = return ()
282
283 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
284 sendMessage msg = do
285   h <- serverIn <$> ask
286   logMsg LogClient msg
287   liftIO $ B.hPut h (addHeader $ encode msg)
288
289 -- | Execute a block f that will throw a 'TimeoutException'
290 -- after duration seconds. This will override the global timeout
291 -- for waiting for messages to arrive defined in 'SessionConfig'.
292 withTimeout :: Int -> Session a -> Session a
293 withTimeout duration f = do
294   chan <- asks messageChan
295   timeoutId <- curTimeoutId <$> get 
296   modify $ \s -> s { overridingTimeout = True }
297   liftIO $ forkIO $ do
298     threadDelay (duration * 1000000)
299     writeChan chan (TimeoutMessage timeoutId)
300   res <- f
301   modify $ \s -> s { curTimeoutId = timeoutId + 1,
302                      overridingTimeout = False 
303                    }
304   return res
305
306 data LogMsgType = LogServer | LogClient
307   deriving Eq
308
309 -- | Logs the message if the config specified it
310 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
311        => LogMsgType -> a -> m ()
312 logMsg t msg = do
313   shouldLog <- asks $ logMessages . config
314   shouldColor <- asks $ logColor . config
315   liftIO $ when shouldLog $ do
316     when shouldColor $ setSGR [SetColor Foreground Dull color]
317     putStrLn $ arrow ++ showPretty msg
318     when shouldColor $ setSGR [Reset]
319
320   where arrow
321           | t == LogServer  = "<-- "
322           | otherwise       = "--> "
323         color
324           | t == LogServer  = Magenta
325           | otherwise       = Cyan
326   
327         showPretty = B.unpack . encodePretty