Add lspConfig option
[lsp-test.git] / src / Language / Haskell / LSP / Test / Session.hs
index b38d1b7fe1d4885de1354036be174aa10f484c93..f0d410afed37c295a9ebdac061471056f1bb882b 100644 (file)
@@ -1,18 +1,29 @@
 {-# LANGUAGE OverloadedStrings #-}
 {-# LANGUAGE FlexibleInstances #-}
 {-# LANGUAGE OverloadedStrings #-}
 {-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE MultiParamTypeClasses #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE RankNTypes #-}
 
 module Language.Haskell.LSP.Test.Session
   ( Session
   , SessionConfig(..)
 
 module Language.Haskell.LSP.Test.Session
   ( Session
   , SessionConfig(..)
+  , defaultConfig
+  , SessionMessage(..)
   , SessionContext(..)
   , SessionState(..)
   , SessionContext(..)
   , SessionState(..)
-  , MonadSessionConfig(..)
   , runSessionWithHandles
   , get
   , put
   , modify
   , modifyM
   , runSessionWithHandles
   , get
   , put
   , modify
   , modifyM
-  , ask)
+  , ask
+  , asks
+  , sendMessage
+  , updateState
+  , withTimeout
+  , logMsg
+  , LogMsgType(..)
+  )
 
 where
 
 
 where
 
@@ -20,16 +31,18 @@ import Control.Concurrent hiding (yield)
 import Control.Exception
 import Control.Lens hiding (List)
 import Control.Monad
 import Control.Exception
 import Control.Lens hiding (List)
 import Control.Monad
+import Control.Monad.Fail
 import Control.Monad.IO.Class
 import Control.Monad.Except
 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
 import qualified Control.Monad.Trans.Reader as Reader (ask)
 import Control.Monad.Trans.State (StateT, runStateT)
 import Control.Monad.IO.Class
 import Control.Monad.Except
 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
 import qualified Control.Monad.Trans.Reader as Reader (ask)
 import Control.Monad.Trans.State (StateT, runStateT)
-import qualified Control.Monad.Trans.State as State (get, put, modify)
+import qualified Control.Monad.Trans.State as State (get, put)
 import qualified Data.ByteString.Lazy.Char8 as B
 import Data.Aeson
 import qualified Data.ByteString.Lazy.Char8 as B
 import Data.Aeson
-import Data.Conduit hiding (await)
-import Data.Conduit.Parser
+import Data.Aeson.Encode.Pretty
+import Data.Conduit as Conduit
+import Data.Conduit.Parser as Parser
 import Data.Default
 import Data.Foldable
 import Data.List
 import Data.Default
 import Data.Foldable
 import Data.List
@@ -38,13 +51,15 @@ import qualified Data.Text as T
 import qualified Data.Text.IO as T
 import qualified Data.HashMap.Strict as HashMap
 import Data.Maybe
 import qualified Data.Text.IO as T
 import qualified Data.HashMap.Strict as HashMap
 import Data.Maybe
+import Data.Function
 import Language.Haskell.LSP.Messages
 import Language.Haskell.LSP.Messages
-import Language.Haskell.LSP.TH.ClientCapabilities
+import Language.Haskell.LSP.Types.Capabilities
 import Language.Haskell.LSP.Types
 import Language.Haskell.LSP.Types
+import Language.Haskell.LSP.Types.Lens hiding (error)
 import Language.Haskell.LSP.VFS
 import Language.Haskell.LSP.VFS
-import Language.Haskell.LSP.Test.Compat
 import Language.Haskell.LSP.Test.Decoding
 import Language.Haskell.LSP.Test.Exceptions
 import Language.Haskell.LSP.Test.Decoding
 import Language.Haskell.LSP.Test.Exceptions
+import System.Console.ANSI
 import System.Directory
 import System.IO
 
 import System.Directory
 import System.IO
 
@@ -53,104 +68,125 @@ import System.IO
 -- You can send and receive messages to the server within 'Session' via 'getMessage',
 -- 'sendRequest' and 'sendNotification'.
 --
 -- You can send and receive messages to the server within 'Session' via 'getMessage',
 -- 'sendRequest' and 'sendNotification'.
 --
--- @
--- runSession \"path\/to\/root\/dir\" $ do
---   docItem <- getDocItem "Desktop/simple.hs" "haskell"
---   sendNotification TextDocumentDidOpen (DidOpenTextDocumentParams docItem)
---   diagnostics <- getMessage :: Session PublishDiagnosticsNotification
--- @
+
 type Session = ParserStateReader FromServerMessage SessionState SessionContext IO
 
 type Session = ParserStateReader FromServerMessage SessionState SessionContext IO
 
+instance MonadFail Session where
+  fail s = do
+    lastMsg <- fromJust . lastReceivedMessage <$> get
+    liftIO $ throw (UnexpectedMessage s lastMsg)
+
 -- | Stuff you can configure for a 'Session'.
 data SessionConfig = SessionConfig
 -- | Stuff you can configure for a 'Session'.
 data SessionConfig = SessionConfig
-  {
-    capabilities :: ClientCapabilities -- ^ Specific capabilities the client should advertise. Default is yes to everything.
-  , timeout :: Int -- ^ Maximum time to wait for a request in seconds. Defaults to 60.
-  , logStdErr :: Bool -- ^ When True redirects the servers stderr output to haskell-lsp-test's stdout. Defaults to False
+  { messageTimeout :: Int  -- ^ Maximum time to wait for a message in seconds, defaults to 60.
+  , logStdErr      :: Bool -- ^ Redirect the server's stderr to this stdout, defaults to False.
+  , logMessages    :: Bool -- ^ Trace the messages sent and received to stdout, defaults to True.
+  , logColor       :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
+  , lspConfig      :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
   }
 
   }
 
-instance Default SessionConfig where
-  def = SessionConfig def 60 False
+-- | The configuration used in 'Language.Haskell.LSP.Test.runSession'.
+defaultConfig :: SessionConfig
+defaultConfig = SessionConfig 60 False True True Nothing
 
 
-class Monad m => MonadSessionConfig m where
-  sessionConfig :: m SessionConfig
+instance Default SessionConfig where
+  def = defaultConfig
 
 
-instance Monad m => MonadSessionConfig (StateT SessionState (ReaderT SessionContext m)) where
-  sessionConfig = config <$> lift Reader.ask
+data SessionMessage = ServerMessage FromServerMessage
+                    | TimeoutMessage Int
+  deriving Show
 
 data SessionContext = SessionContext
   {
     serverIn :: Handle
   , rootDir :: FilePath
 
 data SessionContext = SessionContext
   {
     serverIn :: Handle
   , rootDir :: FilePath
-  , messageChan :: Chan FromServerMessage
+  , messageChan :: Chan SessionMessage
   , requestMap :: MVar RequestMap
   , initRsp :: MVar InitializeResponse
   , config :: SessionConfig
   , requestMap :: MVar RequestMap
   , initRsp :: MVar InitializeResponse
   , config :: SessionConfig
+  , sessionCapabilities :: ClientCapabilities
   }
 
   }
 
+class Monad m => HasReader r m where
+  ask :: m r
+  asks :: (r -> b) -> m b
+  asks f = f <$> ask
+
+instance Monad m => HasReader r (ParserStateReader a s r m) where
+  ask = lift $ lift Reader.ask
+
+instance Monad m => HasReader SessionContext (ConduitM a b (StateT s (ReaderT SessionContext m))) where
+  ask = lift $ lift Reader.ask
+
 data SessionState = SessionState
   {
     curReqId :: LspId
   , vfs :: VFS
 data SessionState = SessionState
   {
     curReqId :: LspId
   , vfs :: VFS
+  , curDiagnostics :: Map.Map Uri [Diagnostic]
+  , curTimeoutId :: Int
+  , overridingTimeout :: Bool
+  -- ^ The last received message from the server.
+  -- Used for providing exception information
+  , lastReceivedMessage :: Maybe FromServerMessage
   }
 
   }
 
-type ParserStateReader a s r m = ConduitParser a (StateT s (ReaderT r m))
-
-type SessionProcessor = ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO))
-
-runSession :: Chan FromServerMessage -> SessionProcessor () -> SessionContext -> SessionState -> Session a -> IO (a, SessionState)
-runSession chan preprocessor context state session = runReaderT (runStateT conduit state) context
-  where conduit = runConduit $ chanSource chan .| preprocessor .| runConduitParser (catchError session handler)
-        handler e@(Unexpected "ConduitParser.empty") = do
+class Monad m => HasState s m where
+  get :: m s
 
 
-          -- Horrible way to get last item in conduit:
-          -- Add a fake message so we can tell when to stop
-          liftIO $ writeChan chan (RspShutdown (ResponseMessage "EMPTY" IdRspNull Nothing Nothing))
-          x <- peek
-          case x of
-            Just x -> do
-              lastMsg <- skipToEnd x
-              name <- getParserName
-              liftIO $ throw (UnexpectedMessageException (T.unpack name) lastMsg)
-            Nothing -> throw e
+  put :: s -> m ()
 
 
-        handler e = throw e
+  modify :: (s -> s) -> m ()
+  modify f = get >>= put . f
 
 
-        skipToEnd x = do
-          y <- peek
-          case y of
-            Just (RspShutdown (ResponseMessage "EMPTY" IdRspNull Nothing Nothing)) -> return x
-            Just _ -> await >>= skipToEnd
-            Nothing -> return x
+  modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
+  modifyM f = get >>= f >>= put
 
 
-get :: Monad m => ParserStateReader a s r m s
+instance Monad m => HasState s (ParserStateReader a s r m) where
   get = lift State.get
   get = lift State.get
+  put = lift . State.put
 
 
-put :: Monad m => s -> ParserStateReader a s r m ()
+instance Monad m => HasState SessionState (ConduitM a b (StateT SessionState m))
+ where
+  get = lift State.get
   put = lift . State.put
 
   put = lift . State.put
 
-modify :: Monad m => (s -> s) -> ParserStateReader a s r m ()
-modify = lift . State.modify
+type ParserStateReader a s r m = ConduitParser a (StateT s (ReaderT r m))
 
 
-modifyM :: Monad m => (s -> m s) -> ParserStateReader a s r m ()
-modifyM f = do
-  old <- lift State.get
-  new <- lift $ lift $ lift $ f old
-  lift $ State.put new
+runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
+runSession context state session = runReaderT (runStateT conduit state) context
+  where
+    conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
 
 
-ask :: Monad m => ParserStateReader a s r m r
-ask = lift $ lift Reader.ask
+    handler (Unexpected "ConduitParser.empty") = do
+      lastMsg <- fromJust . lastReceivedMessage <$> get
+      name <- getParserName
+      liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
+
+    handler e = throw e
+
+    chanSource = do
+      msg <- liftIO $ readChan (messageChan context)
+      yield msg
+      chanSource
+
+    watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
+    watchdog = Conduit.awaitForever $ \msg -> do
+      curId <- curTimeoutId <$> get
+      case msg of
+        ServerMessage sMsg -> yield sMsg
+        TimeoutMessage tId -> when (curId == tId) $ throw Timeout
 
 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
 -- It also does not automatically send initialize and exit messages.
 runSessionWithHandles :: Handle -- ^ Server in
                       -> Handle -- ^ Server out
 
 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
 -- It also does not automatically send initialize and exit messages.
 runSessionWithHandles :: Handle -- ^ Server in
                       -> Handle -- ^ Server out
-                      -> (Handle -> Session ()) -- ^ Server listener
+                      -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
                       -> SessionConfig
                       -> SessionConfig
-                      -> FilePath
+                      -> ClientCapabilities
+                      -> FilePath -- ^ Root directory
                       -> Session a
                       -> IO a
                       -> Session a
                       -> IO a
-runSessionWithHandles serverIn serverOut serverHandler config rootDir session = do
+runSessionWithHandles serverIn serverOut serverHandler config caps rootDir session = do
   absRootDir <- canonicalizePath rootDir
 
   hSetBuffering serverIn  NoBuffering
   absRootDir <- canonicalizePath rootDir
 
   hSetBuffering serverIn  NoBuffering
@@ -158,78 +194,138 @@ runSessionWithHandles serverIn serverOut serverHandler config rootDir session =
 
   reqMap <- newMVar newRequestMap
   messageChan <- newChan
 
   reqMap <- newMVar newRequestMap
   messageChan <- newChan
-  meaninglessChan <- newChan
   initRsp <- newEmptyMVar
 
   initRsp <- newEmptyMVar
 
-  let context = SessionContext serverIn absRootDir messageChan reqMap initRsp config
-      initState = SessionState (IdInt 0) mempty
+  let context = SessionContext serverIn absRootDir messageChan reqMap initRsp config caps
+      initState = SessionState (IdInt 0) mempty mempty 0 False Nothing
 
 
-  threadId <- forkIO $ void $ runSession meaninglessChan processor context initState (serverHandler serverOut)
-  (result, _) <- runSession messageChan processor context initState session
+  threadId <- forkIO $ void $ serverHandler serverOut context
+  (result, _) <- runSession context initState session
 
   killThread threadId
 
   return result
 
 
   killThread threadId
 
   return result
 
-  where processor :: SessionProcessor ()
-        processor = awaitForever $ \msg -> do
-          processTextChanges msg
+updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
+updateStateC = awaitForever $ \msg -> do
+  updateState msg
   yield msg
 
   yield msg
 
+updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m) => FromServerMessage -> m ()
+updateState (NotPublishDiagnostics n) = do
+  let List diags = n ^. params . diagnostics
+      doc = n ^. params . uri
+  modify (\s ->
+    let newDiags = Map.insert doc diags (curDiagnostics s)
+      in s { curDiagnostics = newDiags })
+
+updateState (ReqApplyWorkspaceEdit r) = do
 
 
-processTextChanges :: FromServerMessage -> SessionProcessor ()
-processTextChanges (ReqApplyWorkspaceEdit r) = do
-  changeParams <- case r ^. params . edit . documentChanges of
-    Just (List cs) -> mapM applyTextDocumentEdit cs
+  allChangeParams <- case r ^. params . edit . documentChanges of
+    Just (List cs) -> do
+      mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
+      return $ map getParams cs
     Nothing -> case r ^. params . edit . changes of
     Nothing -> case r ^. params . edit . changes of
-      Just cs -> concat <$> mapM (uncurry applyChange) (HashMap.toList cs)
-      Nothing -> return []
+      Just cs -> do
+        mapM_ checkIfNeedsOpened (HashMap.keys cs)
+        return $ concatMap (uncurry getChangeParams) (HashMap.toList cs)
+      Nothing -> error "No changes!"
 
 
-  let groupedParams = groupBy (\a b -> (a ^. textDocument == b ^. textDocument)) changeParams
-      mergedParams = map mergeParams groupedParams
+  modifyM $ \s -> do
+    newVFS <- liftIO $ changeFromServerVFS (vfs s) r
+    return $ s { vfs = newVFS }
 
 
-  ctx <- lift $ lift Reader.ask
+  let groupedParams = groupBy (\a b -> (a ^. textDocument == b ^. textDocument)) allChangeParams
+      mergedParams = map mergeParams groupedParams
 
   -- TODO: Don't do this when replaying a session
 
   -- TODO: Don't do this when replaying a session
-  forM_ mergedParams $ \p -> do
-    let h = serverIn ctx
-        msg = NotificationMessage "2.0" TextDocumentDidChange p
-    liftIO $ B.hPut h $ addHeader (encode msg)
+  forM_ mergedParams (sendMessage . NotificationMessage "2.0" TextDocumentDidChange)
 
 
-  where applyTextDocumentEdit (TextDocumentEdit docId (List edits)) = do
-          oldVFS <- vfs <$> lift State.get
-          ctx <- lift $ lift Reader.ask
+  -- Update VFS to new document versions
+  let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
+      latestVersions = map ((^. textDocument) . last) sortedVersions
+      bumpedVersions = map (version . _Just +~ 1) latestVersions
 
 
+  forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
+    modify $ \s ->
+      let oldVFS = vfs s
+          update (VirtualFile oldV t) = VirtualFile (fromMaybe oldV v) t
+          newVFS = Map.adjust update uri oldVFS
+      in s { vfs = newVFS }
+
+  where checkIfNeedsOpened uri = do
+          oldVFS <- vfs <$> get
+          ctx <- ask
 
           -- if its not open, open it
 
           -- if its not open, open it
-          unless ((docId ^. uri) `Map.member` oldVFS) $ do
-            let fp = fromJust $ uriToFilePath (docId ^. uri)
+          unless (uri `Map.member` oldVFS) $ do
+            let fp = fromJust $ uriToFilePath uri
             contents <- liftIO $ T.readFile fp
             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
                 msg = NotificationMessage "2.0" TextDocumentDidOpen (DidOpenTextDocumentParams item)
             liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
 
             contents <- liftIO $ T.readFile fp
             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
                 msg = NotificationMessage "2.0" TextDocumentDidOpen (DidOpenTextDocumentParams item)
             liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
 
-            oldVFS <- vfs <$> lift State.get
-            newVFS <- liftIO $ openVFS oldVFS msg
-            lift $ State.modify (\s -> s { vfs = newVFS })
-
-          -- we might have updated it above
-          oldVFS <- vfs <$> lift State.get
+            modifyM $ \s -> do
+              newVFS <- liftIO $ openVFS (vfs s) msg
+              return $ s { vfs = newVFS }
 
 
+        getParams (TextDocumentEdit docId (List edits)) =
           let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
           let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
-              params = DidChangeTextDocumentParams docId (List changeEvents)
-          newVFS <- liftIO $ changeVFS oldVFS (fmClientDidChangeTextDocumentNotification params)
-          lift $ State.modify (\s -> s { vfs = newVFS })
+            in DidChangeTextDocumentParams docId (List changeEvents)
 
 
-          return params
-
-        textDocumentVersions uri = map (VersionedTextDocumentIdentifier uri) [0..]
+        textDocumentVersions uri = map (VersionedTextDocumentIdentifier uri . Just) [0..]
 
         textDocumentEdits uri edits = map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip (textDocumentVersions uri) edits
 
 
         textDocumentEdits uri edits = map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip (textDocumentVersions uri) edits
 
-        applyChange uri (List edits) = mapM applyTextDocumentEdit (textDocumentEdits uri (reverse edits))
+        getChangeParams uri (List edits) = map getParams (textDocumentEdits uri (reverse edits))
 
         mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
         mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
                               in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
 
         mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
         mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
                               in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
-processTextChanges _ = return ()
+updateState _ = return ()
+
+sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
+sendMessage msg = do
+  h <- serverIn <$> ask
+  logMsg LogClient msg
+  liftIO $ B.hPut h (addHeader $ encode msg)
+
+-- | Execute a block f that will throw a 'TimeoutException'
+-- after duration seconds. This will override the global timeout
+-- for waiting for messages to arrive defined in 'SessionConfig'.
+withTimeout :: Int -> Session a -> Session a
+withTimeout duration f = do
+  chan <- asks messageChan
+  timeoutId <- curTimeoutId <$> get
+  modify $ \s -> s { overridingTimeout = True }
+  liftIO $ forkIO $ do
+    threadDelay (duration * 1000000)
+    writeChan chan (TimeoutMessage timeoutId)
+  res <- f
+  modify $ \s -> s { curTimeoutId = timeoutId + 1,
+                     overridingTimeout = False
+                   }
+  return res
+
+data LogMsgType = LogServer | LogClient
+  deriving Eq
+
+-- | Logs the message if the config specified it
+logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
+       => LogMsgType -> a -> m ()
+logMsg t msg = do
+  shouldLog <- asks $ logMessages . config
+  shouldColor <- asks $ logColor . config
+  liftIO $ when shouldLog $ do
+    when shouldColor $ setSGR [SetColor Foreground Dull color]
+    putStrLn $ arrow ++ showPretty msg
+    when shouldColor $ setSGR [Reset]
+
+  where arrow
+          | t == LogServer  = "<-- "
+          | otherwise       = "--> "
+        color
+          | t == LogServer  = Magenta
+          | otherwise       = Cyan
+
+        showPretty = B.unpack . encodePretty