Add applyEdit and getVersionedDoc helpers
[opengl.git] / src / Language / Haskell / LSP / Test / Session.hs
1 {-# LANGUAGE OverloadedStrings #-}
2 {-# LANGUAGE FlexibleInstances #-}
3 {-# LANGUAGE MultiParamTypeClasses #-}
4 {-# LANGUAGE FlexibleContexts #-}
5
6 module Language.Haskell.LSP.Test.Session
7   ( Session
8   , SessionConfig(..)
9   , SessionMessage(..)
10   , SessionContext(..)
11   , SessionState(..)
12   , runSessionWithHandles
13   , get
14   , put
15   , modify
16   , ask
17   , asks
18   , sendMessage
19   , updateState
20   , withTimeout
21   )
22
23 where
24
25 import Control.Concurrent hiding (yield)
26 import Control.Exception
27 import Control.Lens hiding (List)
28 import Control.Monad
29 import Control.Monad.IO.Class
30 import Control.Monad.Except
31 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
32 import qualified Control.Monad.Trans.Reader as Reader (ask)
33 import Control.Monad.Trans.State (StateT, runStateT)
34 import qualified Control.Monad.Trans.State as State (get, put)
35 import qualified Data.ByteString.Lazy.Char8 as B
36 import Data.Aeson
37 import Data.Conduit as Conduit
38 import Data.Conduit.Parser as Parser
39 import Data.Default
40 import Data.Foldable
41 import Data.List
42 import qualified Data.Map as Map
43 import qualified Data.Text as T
44 import qualified Data.Text.IO as T
45 import qualified Data.HashMap.Strict as HashMap
46 import Data.Maybe
47 import Data.Function
48 import Language.Haskell.LSP.Messages
49 import Language.Haskell.LSP.TH.ClientCapabilities
50 import Language.Haskell.LSP.Types hiding (error)
51 import Language.Haskell.LSP.VFS
52 import Language.Haskell.LSP.Test.Decoding
53 import Language.Haskell.LSP.Test.Exceptions
54 import System.Console.ANSI
55 import System.Directory
56 import System.IO
57
58 -- | A session representing one instance of launching and connecting to a server.
59 -- 
60 -- You can send and receive messages to the server within 'Session' via 'getMessage',
61 -- 'sendRequest' and 'sendNotification'.
62 --
63 -- @
64 -- runSession \"path\/to\/root\/dir\" $ do
65 --   docItem <- getDocItem "Desktop/simple.hs" "haskell"
66 --   sendNotification TextDocumentDidOpen (DidOpenTextDocumentParams docItem)
67 --   diagnostics <- getMessage :: Session PublishDiagnosticsNotification
68 -- @
69 type Session = ParserStateReader FromServerMessage SessionState SessionContext IO
70
71 -- | Stuff you can configure for a 'Session'.
72 data SessionConfig = SessionConfig
73   {
74     capabilities :: ClientCapabilities -- ^ Specific capabilities the client should advertise. Default is yes to everything.
75   , messageTimeout :: Int -- ^ Maximum time to wait for a message in seconds. Defaults to 60.
76   , logStdErr :: Bool -- ^ When True redirects the servers stderr output to haskell-lsp-test's stdout. Defaults to False
77   }
78
79 instance Default SessionConfig where
80   def = SessionConfig def 60 False
81
82 data SessionMessage = ServerMessage FromServerMessage
83                     | TimeoutMessage Int
84   deriving Show
85
86 data SessionContext = SessionContext
87   {
88     serverIn :: Handle
89   , rootDir :: FilePath
90   , messageChan :: Chan SessionMessage
91   , requestMap :: MVar RequestMap
92   , initRsp :: MVar InitializeResponse
93   , config :: SessionConfig
94   }
95
96 class Monad m => HasReader r m where
97   ask :: m r
98   asks :: (r -> b) -> m b
99   asks f = f <$> ask
100
101 instance Monad m => HasReader r (ParserStateReader a s r m) where
102   ask = lift $ lift Reader.ask
103
104 instance Monad m => HasReader SessionContext (ConduitM a b (StateT s (ReaderT SessionContext m))) where
105   ask = lift $ lift Reader.ask
106
107 data SessionState = SessionState
108   {
109     curReqId :: LspId
110   , vfs :: VFS
111   , curDiagnostics :: Map.Map Uri [Diagnostic]
112   , curTimeoutId :: Int
113   , overridingTimeout :: Bool
114   -- ^ The last received message from the server.
115   -- Used for providing exception information
116   , lastReceivedMessage :: Maybe FromServerMessage
117   }
118
119 class Monad m => HasState s m where
120   get :: m s
121
122   put :: s -> m ()
123
124   modify :: (s -> s) -> m ()
125   modify f = get >>= put . f
126
127 instance Monad m => HasState s (ParserStateReader a s r m) where
128   get = lift State.get
129   put = lift . State.put
130
131 instance Monad m => HasState SessionState (ConduitM a b (StateT SessionState m))
132  where
133   get = lift State.get
134   put = lift . State.put
135
136 type ParserStateReader a s r m = ConduitParser a (StateT s (ReaderT r m))
137
138 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
139 runSession context state session =
140     -- source <- sourceList <$> getChanContents (messageChan context)
141     runReaderT (runStateT conduit state) context
142   where
143     conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
144         
145     handler (Unexpected "ConduitParser.empty") = do
146       lastMsg <- fromJust . lastReceivedMessage <$> get
147       name <- getParserName
148       liftIO $ throw (UnexpectedMessageException (T.unpack name) lastMsg)
149
150     handler e = throw e
151
152     chanSource = do
153       msg <- liftIO $ readChan (messageChan context)
154       yield msg
155       chanSource
156
157
158     watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
159     watchdog = Conduit.awaitForever $ \msg -> do
160       curId <- curTimeoutId <$> get
161       case msg of
162         ServerMessage sMsg -> yield sMsg
163         TimeoutMessage tId -> when (curId == tId) $ throw TimeoutException
164
165 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
166 -- It also does not automatically send initialize and exit messages.
167 runSessionWithHandles :: Handle -- ^ Server in
168                       -> Handle -- ^ Server out
169                       -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
170                       -> SessionConfig
171                       -> FilePath
172                       -> Session a
173                       -> IO a
174 runSessionWithHandles serverIn serverOut serverHandler config rootDir session = do
175   absRootDir <- canonicalizePath rootDir
176
177   hSetBuffering serverIn  NoBuffering
178   hSetBuffering serverOut NoBuffering
179
180   reqMap <- newMVar newRequestMap
181   messageChan <- newChan
182   initRsp <- newEmptyMVar
183
184   let context = SessionContext serverIn absRootDir messageChan reqMap initRsp config
185       initState = SessionState (IdInt 0) mempty mempty 0 False Nothing
186
187   threadId <- forkIO $ void $ serverHandler serverOut context
188   (result, _) <- runSession context initState session
189
190   killThread threadId
191
192   return result
193
194 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
195 updateStateC = awaitForever $ \msg -> do
196   updateState msg
197   yield msg
198
199 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m) => FromServerMessage -> m ()
200 updateState (NotPublishDiagnostics n) = do
201   let List diags = n ^. params . diagnostics
202       doc = n ^. params . uri
203   modify (\s ->
204     let newDiags = Map.insert doc diags (curDiagnostics s)
205       in s { curDiagnostics = newDiags })
206
207 updateState (ReqApplyWorkspaceEdit r) = do
208
209   oldVFS <- vfs <$> get
210
211   allChangeParams <- case r ^. params . edit . documentChanges of
212     Just (List cs) -> do
213       mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
214       return $ map getParams cs
215     Nothing -> case r ^. params . edit . changes of
216       Just cs -> do
217         mapM_ checkIfNeedsOpened (HashMap.keys cs)
218         return $ concatMap (uncurry getChangeParams) (HashMap.toList cs)
219       Nothing -> error "No changes!"
220
221   newVFS <- liftIO $ changeFromServerVFS oldVFS r
222   modify (\s -> s { vfs = newVFS })
223
224   let groupedParams = groupBy (\a b -> (a ^. textDocument == b ^. textDocument)) allChangeParams
225       mergedParams = map mergeParams groupedParams
226
227   -- TODO: Don't do this when replaying a session
228   forM_ mergedParams (sendMessage . NotificationMessage "2.0" TextDocumentDidChange)
229
230   -- Update VFS to new document versions
231   let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
232       latestVersions = map ((^. textDocument) . last) sortedVersions
233       bumpedVersions = map (version . _Just +~ 1) latestVersions
234
235   forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
236     modify $ \s ->
237       let oldVFS = vfs s
238           update (VirtualFile oldV t) = VirtualFile (fromMaybe oldV v) t
239           newVFS = Map.adjust update uri oldVFS
240       in s { vfs = newVFS }
241
242   where checkIfNeedsOpened uri = do
243           oldVFS <- vfs <$> get
244           ctx <- ask
245
246           -- if its not open, open it
247           unless (uri `Map.member` oldVFS) $ do
248             let fp = fromJust $ uriToFilePath uri
249             contents <- liftIO $ T.readFile fp
250             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
251                 msg = NotificationMessage "2.0" TextDocumentDidOpen (DidOpenTextDocumentParams item)
252             liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
253
254             oldVFS <- vfs <$> get
255             newVFS <- liftIO $ openVFS oldVFS msg
256             modify (\s -> s { vfs = newVFS })
257
258         getParams (TextDocumentEdit docId (List edits)) =
259           let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
260             in DidChangeTextDocumentParams docId (List changeEvents)
261
262         textDocumentVersions uri = map (VersionedTextDocumentIdentifier uri . Just) [0..]
263
264         textDocumentEdits uri edits = map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip (textDocumentVersions uri) edits
265
266         getChangeParams uri (List edits) = map getParams (textDocumentEdits uri (reverse edits))
267
268         mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
269         mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
270                               in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
271 updateState _ = return ()
272
273 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
274 sendMessage msg = do
275   h <- serverIn <$> ask
276   let encoded = encode msg
277   liftIO $ do
278
279     setSGR [SetColor Foreground Vivid Cyan]
280     putStrLn $ "--> " ++ B.unpack encoded
281     setSGR [Reset]
282
283     B.hPut h (addHeader encoded)
284
285 -- | Execute a block f that will throw a 'TimeoutException'
286 -- after duration seconds. This will override the global timeout
287 -- for waiting for messages to arrive defined in 'SessionConfig'.
288 withTimeout :: Int -> Session a -> Session a
289 withTimeout duration f = do
290   chan <- asks messageChan
291   timeoutId <- curTimeoutId <$> get 
292   modify $ \s -> s { overridingTimeout = True }
293   liftIO $ forkIO $ do
294     threadDelay (duration * 1000000)
295     writeChan chan (TimeoutMessage timeoutId)
296   res <- f
297   modify $ \s -> s { curTimeoutId = timeoutId + 1,
298                      overridingTimeout = False 
299                    }
300   return res