Plug in hedgehog
[lsp-test.git] / src / Language / Haskell / LSP / Test / Session.hs
1 {-# LANGUAGE OverloadedStrings #-}
2 {-# LANGUAGE FlexibleInstances #-}
3 {-# LANGUAGE MultiParamTypeClasses #-}
4 {-# LANGUAGE FlexibleContexts, UndecidableInstances #-}
5
6 module Language.Haskell.LSP.Test.Session
7   ( SessionT
8   , SessionConfig(..)
9   , SessionMessage(..)
10   , SessionContext(..)
11   , SessionState(..)
12   , runSessionWithHandles
13   , get
14   , put
15   , modify
16   , modifyM
17   , ask
18   , asks
19   , sendMessage
20   , updateState
21   , withTimeout
22   )
23
24 where
25
26 import Conduit
27 import Control.Concurrent hiding (yield)
28 import Control.Exception
29 import Control.Lens hiding (List)
30 import Control.Monad
31 import Control.Monad.IO.Class
32 import Control.Monad.Except
33 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
34 import qualified Control.Monad.Trans.Reader as Reader (ask)
35 import Control.Monad.Trans.State (StateT, runStateT)
36 import qualified Control.Monad.Trans.State as State (get, put)
37 import qualified Data.ByteString.Lazy.Char8 as B
38 import Data.Aeson
39 import Data.Aeson.Encode.Pretty
40 import Data.Conduit.Parser as Parser
41 import Data.Default
42 import Data.Foldable
43 import Data.List
44 import qualified Data.Map as Map
45 import qualified Data.Text as T
46 import qualified Data.Text.IO as T
47 import qualified Data.HashMap.Strict as HashMap
48 import Data.Maybe
49 import Data.Function
50 import Language.Haskell.LSP.Messages
51 import Language.Haskell.LSP.Types.Capabilities
52 import Language.Haskell.LSP.Types hiding (error)
53 import Language.Haskell.LSP.VFS
54 import Language.Haskell.LSP.Test.Decoding
55 import Language.Haskell.LSP.Test.Exceptions
56 import System.Console.ANSI
57 import System.Directory
58 import System.IO
59
60 -- | A session representing one instance of launching and connecting to a server.
61 -- 
62 -- You can send and receive messages to the server within 'Session' via 'getMessage',
63 -- 'sendRequest' and 'sendNotification'.
64 --
65 -- @
66 -- runSession \"path\/to\/root\/dir\" $ do
67 --   docItem <- getDocItem "Desktop/simple.hs" "haskell"
68 --   sendNotification TextDocumentDidOpen (DidOpenTextDocumentParams docItem)
69 --   diagnostics <- getMessage :: Session PublishDiagnosticsNotification
70 -- @
71 type SessionT m = ParserStateReader FromServerMessage SessionState SessionContext m
72
73 -- | Stuff you can configure for a 'Session'.
74 data SessionConfig = SessionConfig
75   {
76     messageTimeout :: Int -- ^ Maximum time to wait for a message in seconds. Defaults to 60.
77   , logStdErr :: Bool -- ^ When True redirects the servers stderr output to haskell-lsp-test's stdout. Defaults to False.
78   , logMessages :: Bool -- ^ When True traces the communication between client and server to stdout. Defaults to True.
79   }
80
81 instance Default SessionConfig where
82   def = SessionConfig 60 False True
83
84 data SessionMessage = ServerMessage FromServerMessage
85                     | TimeoutMessage Int
86   deriving Show
87
88 data SessionContext = SessionContext
89   {
90     serverIn :: Handle
91   , rootDir :: FilePath
92   , messageChan :: Chan SessionMessage
93   , requestMap :: MVar RequestMap
94   , initRsp :: MVar InitializeResponse
95   , config :: SessionConfig
96   , sessionCapabilities :: ClientCapabilities
97   }
98
99 class Monad m => HasReader r m where
100   ask :: m r
101   asks :: (r -> b) -> m b
102   asks f = f <$> ask
103
104 instance Monad m => HasReader r (ParserStateReader a s r m) where
105   ask = lift $ lift Reader.ask
106
107 instance Monad m => HasReader SessionContext (ConduitM a b (StateT s (ReaderT SessionContext m))) where
108   ask = lift $ lift Reader.ask
109
110 data SessionState = SessionState
111   {
112     curReqId :: LspId
113   , vfs :: VFS
114   , curDiagnostics :: Map.Map Uri [Diagnostic]
115   , curTimeoutId :: Int
116   , overridingTimeout :: Bool
117   -- ^ The last received message from the server.
118   -- Used for providing exception information
119   , lastReceivedMessage :: Maybe FromServerMessage
120   }
121
122 class Monad m => HasState s m where
123   get :: m s
124
125   put :: s -> m ()
126
127   modify :: (s -> s) -> m ()
128   modify f = get >>= put . f
129
130   modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
131   modifyM f = get >>= f >>= put
132
133 instance Monad m => HasState s (ParserStateReader a s r m) where
134   get = lift State.get
135   put = lift . State.put
136
137 instance Monad m => HasState SessionState (ConduitM a b (StateT SessionState m))
138  where
139   get = lift State.get
140   put = lift . State.put
141
142 type ParserStateReader a s r m = ConduitParser a (StateT s (ReaderT r m))
143
144 runSession :: (MonadIO m, MonadThrow m) => SessionContext -> SessionState -> SessionT m a -> m (a, SessionState)
145 runSession context state session = runReaderT (runStateT conduit state) context
146   where
147     conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
148         
149     handler :: MonadIO m => ConduitParserException -> SessionT m a
150     handler (Unexpected "ConduitParser.empty") = do
151       lastMsg <- fromJust . lastReceivedMessage <$> get
152       name <- getParserName
153       liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
154
155     handler e = liftIO $ throw e
156
157     chanSource :: MonadIO m => ConduitT () SessionMessage m ()
158     chanSource = do
159       msg <- liftIO $ readChan (messageChan context)
160       yield msg
161       chanSource
162
163
164     watchdog :: MonadIO m => ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext m)) ()
165     watchdog = Conduit.awaitForever $ \msg -> do
166       curId <- curTimeoutId <$> get
167       case msg of
168         ServerMessage sMsg -> yield sMsg
169         TimeoutMessage tId -> when (curId == tId) $ throw Timeout
170
171 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
172 -- It also does not automatically send initialize and exit messages.
173 runSessionWithHandles :: (MonadIO m, MonadThrow m)
174                       => Handle -- ^ Server in
175                       -> Handle -- ^ Server out
176                       -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
177                       -> SessionConfig
178                       -> ClientCapabilities
179                       -> FilePath -- ^ Root directory
180                       -> SessionT m a
181                       -> m a
182 runSessionWithHandles serverIn serverOut serverHandler config caps rootDir session = do
183   absRootDir <- liftIO $ canonicalizePath rootDir
184
185   liftIO $ do
186     hSetBuffering serverIn  NoBuffering
187     hSetBuffering serverOut NoBuffering
188
189   reqMap <- liftIO $ newMVar newRequestMap
190   messageChan <- liftIO newChan
191   initRsp <- liftIO newEmptyMVar
192
193   let context = SessionContext serverIn absRootDir messageChan reqMap initRsp config caps
194       initState = SessionState (IdInt 0) mempty mempty 0 False Nothing
195
196   threadId <- liftIO $ forkIO $ void $ serverHandler serverOut context
197   (result, _) <- runSession context initState session
198
199   liftIO $ killThread threadId
200
201   return result
202
203 updateStateC :: MonadIO m => ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext m)) ()
204 updateStateC = awaitForever $ \msg -> do
205   updateState msg
206   yield msg
207
208 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m) => FromServerMessage -> m ()
209 updateState (NotPublishDiagnostics n) = do
210   let List diags = n ^. params . diagnostics
211       doc = n ^. params . uri
212   modify (\s ->
213     let newDiags = Map.insert doc diags (curDiagnostics s)
214       in s { curDiagnostics = newDiags })
215
216 updateState (ReqApplyWorkspaceEdit r) = do
217
218   allChangeParams <- case r ^. params . edit . documentChanges of
219     Just (List cs) -> do
220       mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
221       return $ map getParams cs
222     Nothing -> case r ^. params . edit . changes of
223       Just cs -> do
224         mapM_ checkIfNeedsOpened (HashMap.keys cs)
225         return $ concatMap (uncurry getChangeParams) (HashMap.toList cs)
226       Nothing -> error "No changes!"
227
228   modifyM $ \s -> do
229     newVFS <- liftIO $ changeFromServerVFS (vfs s) r
230     return $ s { vfs = newVFS }
231
232   let groupedParams = groupBy (\a b -> (a ^. textDocument == b ^. textDocument)) allChangeParams
233       mergedParams = map mergeParams groupedParams
234
235   -- TODO: Don't do this when replaying a session
236   forM_ mergedParams (sendMessage . NotificationMessage "2.0" TextDocumentDidChange)
237
238   -- Update VFS to new document versions
239   let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
240       latestVersions = map ((^. textDocument) . last) sortedVersions
241       bumpedVersions = map (version . _Just +~ 1) latestVersions
242
243   forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
244     modify $ \s ->
245       let oldVFS = vfs s
246           update (VirtualFile oldV t) = VirtualFile (fromMaybe oldV v) t
247           newVFS = Map.adjust update uri oldVFS
248       in s { vfs = newVFS }
249
250   where checkIfNeedsOpened uri = do
251           oldVFS <- vfs <$> get
252           ctx <- ask
253
254           -- if its not open, open it
255           unless (uri `Map.member` oldVFS) $ do
256             let fp = fromJust $ uriToFilePath uri
257             contents <- liftIO $ T.readFile fp
258             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
259                 msg = NotificationMessage "2.0" TextDocumentDidOpen (DidOpenTextDocumentParams item)
260             liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
261
262             modifyM $ \s -> do 
263               newVFS <- liftIO $ openVFS (vfs s) msg
264               return $ s { vfs = newVFS }
265
266         getParams (TextDocumentEdit docId (List edits)) =
267           let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
268             in DidChangeTextDocumentParams docId (List changeEvents)
269
270         textDocumentVersions uri = map (VersionedTextDocumentIdentifier uri . Just) [0..]
271
272         textDocumentEdits uri edits = map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip (textDocumentVersions uri) edits
273
274         getChangeParams uri (List edits) = map getParams (textDocumentEdits uri (reverse edits))
275
276         mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
277         mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
278                               in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
279 updateState _ = return ()
280
281 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
282 sendMessage msg = do
283   h <- serverIn <$> ask
284   let encoded = encodePretty msg
285
286   shouldLog <- asks $ logMessages . config
287   liftIO $ when shouldLog $ do
288   
289     setSGR [SetColor Foreground Dull Cyan]
290     putStrLn $ "--> " ++ B.unpack encoded
291     setSGR [Reset]
292
293     B.hPut h (addHeader encoded)
294
295 -- | Execute a block f that will throw a 'TimeoutException'
296 -- after duration seconds. This will override the global timeout
297 -- for waiting for messages to arrive defined in 'SessionConfig'.
298 withTimeout :: MonadIO m => Int -> SessionT m a -> SessionT m a
299 withTimeout duration f = do
300   chan <- asks messageChan
301   timeoutId <- curTimeoutId <$> get 
302   modify $ \s -> s { overridingTimeout = True }
303   liftIO $ forkIO $ do
304     threadDelay (duration * 1000000)
305     writeChan chan (TimeoutMessage timeoutId)
306   res <- f
307   modify $ \s -> s { curTimeoutId = timeoutId + 1,
308                      overridingTimeout = False 
309                    }
310   return res