Add lspConfig option
[lsp-test.git] / src / Language / Haskell / LSP / Test / Session.hs
index ee6d871e070220fb8b866c7cb9512790731d0e2f..f0d410afed37c295a9ebdac061471056f1bb882b 100644 (file)
@@ -1,17 +1,29 @@
 {-# LANGUAGE OverloadedStrings #-}
 {-# LANGUAGE FlexibleInstances #-}
 {-# LANGUAGE OverloadedStrings #-}
 {-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE MultiParamTypeClasses #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE RankNTypes #-}
 
 module Language.Haskell.LSP.Test.Session
   ( Session
   , SessionConfig(..)
 
 module Language.Haskell.LSP.Test.Session
   ( Session
   , SessionConfig(..)
+  , defaultConfig
+  , SessionMessage(..)
   , SessionContext(..)
   , SessionState(..)
   , SessionContext(..)
   , SessionState(..)
-  , MonadSessionConfig(..)
   , runSessionWithHandles
   , get
   , put
   , modify
   , runSessionWithHandles
   , get
   , put
   , modify
-  , ask)
+  , modifyM
+  , ask
+  , asks
+  , sendMessage
+  , updateState
+  , withTimeout
+  , logMsg
+  , LogMsgType(..)
+  )
 
 where
 
 
 where
 
@@ -19,28 +31,35 @@ import Control.Concurrent hiding (yield)
 import Control.Exception
 import Control.Lens hiding (List)
 import Control.Monad
 import Control.Exception
 import Control.Lens hiding (List)
 import Control.Monad
+import Control.Monad.Fail
 import Control.Monad.IO.Class
 import Control.Monad.Except
 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
 import qualified Control.Monad.Trans.Reader as Reader (ask)
 import Control.Monad.Trans.State (StateT, runStateT)
 import Control.Monad.IO.Class
 import Control.Monad.Except
 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
 import qualified Control.Monad.Trans.Reader as Reader (ask)
 import Control.Monad.Trans.State (StateT, runStateT)
-import qualified Control.Monad.Trans.State as State (get, put, modify)
+import qualified Control.Monad.Trans.State as State (get, put)
 import qualified Data.ByteString.Lazy.Char8 as B
 import Data.Aeson
 import qualified Data.ByteString.Lazy.Char8 as B
 import Data.Aeson
-import Data.Conduit hiding (await)
-import Data.Conduit.Parser
+import Data.Aeson.Encode.Pretty
+import Data.Conduit as Conduit
+import Data.Conduit.Parser as Parser
 import Data.Default
 import Data.Foldable
 import Data.List
 import Data.Default
 import Data.Foldable
 import Data.List
+import qualified Data.Map as Map
 import qualified Data.Text as T
 import qualified Data.Text as T
+import qualified Data.Text.IO as T
 import qualified Data.HashMap.Strict as HashMap
 import qualified Data.HashMap.Strict as HashMap
+import Data.Maybe
+import Data.Function
 import Language.Haskell.LSP.Messages
 import Language.Haskell.LSP.Messages
-import Language.Haskell.LSP.TH.ClientCapabilities
+import Language.Haskell.LSP.Types.Capabilities
 import Language.Haskell.LSP.Types
 import Language.Haskell.LSP.Types
+import Language.Haskell.LSP.Types.Lens hiding (error)
 import Language.Haskell.LSP.VFS
 import Language.Haskell.LSP.VFS
-import Language.Haskell.LSP.Test.Compat
 import Language.Haskell.LSP.Test.Decoding
 import Language.Haskell.LSP.Test.Exceptions
 import Language.Haskell.LSP.Test.Decoding
 import Language.Haskell.LSP.Test.Exceptions
+import System.Console.ANSI
 import System.Directory
 import System.IO
 
 import System.Directory
 import System.IO
 
@@ -49,97 +68,125 @@ import System.IO
 -- You can send and receive messages to the server within 'Session' via 'getMessage',
 -- 'sendRequest' and 'sendNotification'.
 --
 -- You can send and receive messages to the server within 'Session' via 'getMessage',
 -- 'sendRequest' and 'sendNotification'.
 --
--- @
--- runSession \"path\/to\/root\/dir\" $ do
---   docItem <- getDocItem "Desktop/simple.hs" "haskell"
---   sendNotification TextDocumentDidOpen (DidOpenTextDocumentParams docItem)
---   diagnostics <- getMessage :: Session PublishDiagnosticsNotification
--- @
+
 type Session = ParserStateReader FromServerMessage SessionState SessionContext IO
 
 type Session = ParserStateReader FromServerMessage SessionState SessionContext IO
 
+instance MonadFail Session where
+  fail s = do
+    lastMsg <- fromJust . lastReceivedMessage <$> get
+    liftIO $ throw (UnexpectedMessage s lastMsg)
+
 -- | Stuff you can configure for a 'Session'.
 data SessionConfig = SessionConfig
 -- | Stuff you can configure for a 'Session'.
 data SessionConfig = SessionConfig
-  {
-    capabilities :: ClientCapabilities, -- ^ Specific capabilities the client should advertise.
-    timeout :: Int -- ^ Maximum time to wait for a request in seconds.
+  { messageTimeout :: Int  -- ^ Maximum time to wait for a message in seconds, defaults to 60.
+  , logStdErr      :: Bool -- ^ Redirect the server's stderr to this stdout, defaults to False.
+  , logMessages    :: Bool -- ^ Trace the messages sent and received to stdout, defaults to True.
+  , logColor       :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
+  , lspConfig      :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
   }
 
   }
 
-instance Default SessionConfig where
-  def = SessionConfig def 60
+-- | The configuration used in 'Language.Haskell.LSP.Test.runSession'.
+defaultConfig :: SessionConfig
+defaultConfig = SessionConfig 60 False True True Nothing
 
 
-class Monad m => MonadSessionConfig m where
-  sessionConfig :: m SessionConfig
+instance Default SessionConfig where
+  def = defaultConfig
 
 
-instance Monad m => MonadSessionConfig (StateT SessionState (ReaderT SessionContext m)) where
-  sessionConfig = config <$> lift Reader.ask
+data SessionMessage = ServerMessage FromServerMessage
+                    | TimeoutMessage Int
+  deriving Show
 
 data SessionContext = SessionContext
   {
     serverIn :: Handle
   , rootDir :: FilePath
 
 data SessionContext = SessionContext
   {
     serverIn :: Handle
   , rootDir :: FilePath
-  , messageChan :: Chan FromServerMessage
+  , messageChan :: Chan SessionMessage
   , requestMap :: MVar RequestMap
   , initRsp :: MVar InitializeResponse
   , config :: SessionConfig
   , requestMap :: MVar RequestMap
   , initRsp :: MVar InitializeResponse
   , config :: SessionConfig
+  , sessionCapabilities :: ClientCapabilities
   }
 
   }
 
+class Monad m => HasReader r m where
+  ask :: m r
+  asks :: (r -> b) -> m b
+  asks f = f <$> ask
+
+instance Monad m => HasReader r (ParserStateReader a s r m) where
+  ask = lift $ lift Reader.ask
+
+instance Monad m => HasReader SessionContext (ConduitM a b (StateT s (ReaderT SessionContext m))) where
+  ask = lift $ lift Reader.ask
+
 data SessionState = SessionState
   {
     curReqId :: LspId
   , vfs :: VFS
 data SessionState = SessionState
   {
     curReqId :: LspId
   , vfs :: VFS
+  , curDiagnostics :: Map.Map Uri [Diagnostic]
+  , curTimeoutId :: Int
+  , overridingTimeout :: Bool
+  -- ^ The last received message from the server.
+  -- Used for providing exception information
+  , lastReceivedMessage :: Maybe FromServerMessage
   }
 
   }
 
-type ParserStateReader a s r m = ConduitParser a (StateT s (ReaderT r m))
-
-type SessionProcessor = ConduitT FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO))
-
-runSession :: Chan FromServerMessage -> SessionProcessor () -> SessionContext -> SessionState -> Session a -> IO (a, SessionState)
-runSession chan preprocessor context state session = runReaderT (runStateT conduit state) context
-  where conduit = runConduit $ chanSource chan .| preprocessor .| runConduitParser (catchError session handler)
-        handler e@(Unexpected "ConduitParser.empty") = do
+class Monad m => HasState s m where
+  get :: m s
 
 
-          -- Horrible way to get last item in conduit:
-          -- Add a fake message so we can tell when to stop
-          liftIO $ writeChan chan (RspShutdown (ResponseMessage "EMPTY" IdRspNull Nothing Nothing))
-          x <- peek
-          case x of
-            Just x -> do
-              lastMsg <- skipToEnd x
-              name <- getParserName
-              liftIO $ throw (UnexpectedMessageException (T.unpack name) lastMsg)
-            Nothing -> throw e
+  put :: s -> m ()
 
 
-        handler e = throw e
+  modify :: (s -> s) -> m ()
+  modify f = get >>= put . f
 
 
-        skipToEnd x = do
-          y <- peek
-          case y of
-            Just (RspShutdown (ResponseMessage "EMPTY" IdRspNull Nothing Nothing)) -> return x
-            Just _ -> await >>= skipToEnd
-            Nothing -> return x
+  modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
+  modifyM f = get >>= f >>= put
 
 
-get :: Monad m => ParserStateReader a s r m s
+instance Monad m => HasState s (ParserStateReader a s r m) where
   get = lift State.get
   get = lift State.get
+  put = lift . State.put
 
 
-put :: Monad m => s -> ParserStateReader a s r m ()
+instance Monad m => HasState SessionState (ConduitM a b (StateT SessionState m))
+ where
+  get = lift State.get
   put = lift . State.put
 
   put = lift . State.put
 
-modify :: Monad m => (s -> s) -> ParserStateReader a s r m ()
-modify = lift . State.modify
+type ParserStateReader a s r m = ConduitParser a (StateT s (ReaderT r m))
 
 
-ask :: Monad m => ParserStateReader a s r m r
-ask = lift $ lift Reader.ask
+runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
+runSession context state session = runReaderT (runStateT conduit state) context
+  where
+    conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
+
+    handler (Unexpected "ConduitParser.empty") = do
+      lastMsg <- fromJust . lastReceivedMessage <$> get
+      name <- getParserName
+      liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
+
+    handler e = throw e
+
+    chanSource = do
+      msg <- liftIO $ readChan (messageChan context)
+      yield msg
+      chanSource
+
+    watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
+    watchdog = Conduit.awaitForever $ \msg -> do
+      curId <- curTimeoutId <$> get
+      case msg of
+        ServerMessage sMsg -> yield sMsg
+        TimeoutMessage tId -> when (curId == tId) $ throw Timeout
 
 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
 -- It also does not automatically send initialize and exit messages.
 runSessionWithHandles :: Handle -- ^ Server in
                       -> Handle -- ^ Server out
 
 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
 -- It also does not automatically send initialize and exit messages.
 runSessionWithHandles :: Handle -- ^ Server in
                       -> Handle -- ^ Server out
-                      -> (Handle -> Session ()) -- ^ Server listener
+                      -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
                       -> SessionConfig
                       -> SessionConfig
-                      -> FilePath
+                      -> ClientCapabilities
+                      -> FilePath -- ^ Root directory
                       -> Session a
                       -> IO a
                       -> Session a
                       -> IO a
-runSessionWithHandles serverIn serverOut serverHandler config rootDir session = do
+runSessionWithHandles serverIn serverOut serverHandler config caps rootDir session = do
   absRootDir <- canonicalizePath rootDir
 
   hSetBuffering serverIn  NoBuffering
   absRootDir <- canonicalizePath rootDir
 
   hSetBuffering serverIn  NoBuffering
@@ -147,53 +194,138 @@ runSessionWithHandles serverIn serverOut serverHandler config rootDir session =
 
   reqMap <- newMVar newRequestMap
   messageChan <- newChan
 
   reqMap <- newMVar newRequestMap
   messageChan <- newChan
-  meaninglessChan <- newChan
   initRsp <- newEmptyMVar
 
   initRsp <- newEmptyMVar
 
-  let context = SessionContext serverIn absRootDir messageChan reqMap initRsp config
-      initState = SessionState (IdInt 0) mempty
+  let context = SessionContext serverIn absRootDir messageChan reqMap initRsp config caps
+      initState = SessionState (IdInt 0) mempty mempty 0 False Nothing
 
 
-  threadId <- forkIO $ void $ runSession meaninglessChan processor context initState (serverHandler serverOut)
-  (result, _) <- runSession messageChan processor context initState session
+  threadId <- forkIO $ void $ serverHandler serverOut context
+  (result, _) <- runSession context initState session
 
   killThread threadId
 
   return result
 
 
   killThread threadId
 
   return result
 
-  where processor :: SessionProcessor ()
-        processor = awaitForever $ \msg -> do
-          processTextChanges msg
+updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
+updateStateC = awaitForever $ \msg -> do
+  updateState msg
   yield msg
 
   yield msg
 
+updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m) => FromServerMessage -> m ()
+updateState (NotPublishDiagnostics n) = do
+  let List diags = n ^. params . diagnostics
+      doc = n ^. params . uri
+  modify (\s ->
+    let newDiags = Map.insert doc diags (curDiagnostics s)
+      in s { curDiagnostics = newDiags })
 
 
-processTextChanges :: FromServerMessage -> SessionProcessor ()
-processTextChanges (ReqApplyWorkspaceEdit r) = do
-  List changeParams <- case r ^. params . edit . documentChanges of
-    Just cs -> mapM applyTextDocumentEdit cs
+updateState (ReqApplyWorkspaceEdit r) = do
+
+  allChangeParams <- case r ^. params . edit . documentChanges of
+    Just (List cs) -> do
+      mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
+      return $ map getParams cs
     Nothing -> case r ^. params . edit . changes of
     Nothing -> case r ^. params . edit . changes of
-      Just cs -> mapM (uncurry applyTextEdit) (List (HashMap.toList cs))
-      Nothing -> return (List [])
+      Just cs -> do
+        mapM_ checkIfNeedsOpened (HashMap.keys cs)
+        return $ concatMap (uncurry getChangeParams) (HashMap.toList cs)
+      Nothing -> error "No changes!"
+
+  modifyM $ \s -> do
+    newVFS <- liftIO $ changeFromServerVFS (vfs s) r
+    return $ s { vfs = newVFS }
 
 
-  let groupedParams = groupBy (\a b -> (a ^. textDocument == b ^. textDocument)) changeParams
+  let groupedParams = groupBy (\a b -> (a ^. textDocument == b ^. textDocument)) allChangeParams
       mergedParams = map mergeParams groupedParams
 
   -- TODO: Don't do this when replaying a session
       mergedParams = map mergeParams groupedParams
 
   -- TODO: Don't do this when replaying a session
-  forM_ mergedParams $ \p -> do
-    h <- serverIn <$> lift (lift Reader.ask)
-    let msg = NotificationMessage "2.0" TextDocumentDidChange p
-    liftIO $ B.hPut h $ addHeader (encode msg)
-
-  where applyTextDocumentEdit (TextDocumentEdit docId (List edits)) = do
-          oldVFS <- vfs <$> lift State.get
+  forM_ mergedParams (sendMessage . NotificationMessage "2.0" TextDocumentDidChange)
+
+  -- Update VFS to new document versions
+  let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
+      latestVersions = map ((^. textDocument) . last) sortedVersions
+      bumpedVersions = map (version . _Just +~ 1) latestVersions
+
+  forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
+    modify $ \s ->
+      let oldVFS = vfs s
+          update (VirtualFile oldV t) = VirtualFile (fromMaybe oldV v) t
+          newVFS = Map.adjust update uri oldVFS
+      in s { vfs = newVFS }
+
+  where checkIfNeedsOpened uri = do
+          oldVFS <- vfs <$> get
+          ctx <- ask
+
+          -- if its not open, open it
+          unless (uri `Map.member` oldVFS) $ do
+            let fp = fromJust $ uriToFilePath uri
+            contents <- liftIO $ T.readFile fp
+            let item = TextDocumentItem (filePathToUri fp) "" 0 contents
+                msg = NotificationMessage "2.0" TextDocumentDidOpen (DidOpenTextDocumentParams item)
+            liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
+
+            modifyM $ \s -> do
+              newVFS <- liftIO $ openVFS (vfs s) msg
+              return $ s { vfs = newVFS }
+
+        getParams (TextDocumentEdit docId (List edits)) =
           let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
           let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
-              params = DidChangeTextDocumentParams docId (List changeEvents)
-          newVFS <- liftIO $ changeVFS oldVFS (fmClientDidChangeTextDocumentNotification params)
-          lift $ State.modify (\s -> s { vfs = newVFS })
-          return params
+            in DidChangeTextDocumentParams docId (List changeEvents)
+
+        textDocumentVersions uri = map (VersionedTextDocumentIdentifier uri . Just) [0..]
 
 
-        applyTextEdit uri edits = applyTextDocumentEdit (TextDocumentEdit (VersionedTextDocumentIdentifier uri 0) edits)
+        textDocumentEdits uri edits = map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip (textDocumentVersions uri) edits
+
+        getChangeParams uri (List edits) = map getParams (textDocumentEdits uri (reverse edits))
 
         mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
         mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
                               in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
 
         mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
         mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
                               in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
-processTextChanges _ = return ()
+updateState _ = return ()
+
+sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
+sendMessage msg = do
+  h <- serverIn <$> ask
+  logMsg LogClient msg
+  liftIO $ B.hPut h (addHeader $ encode msg)
+
+-- | Execute a block f that will throw a 'TimeoutException'
+-- after duration seconds. This will override the global timeout
+-- for waiting for messages to arrive defined in 'SessionConfig'.
+withTimeout :: Int -> Session a -> Session a
+withTimeout duration f = do
+  chan <- asks messageChan
+  timeoutId <- curTimeoutId <$> get
+  modify $ \s -> s { overridingTimeout = True }
+  liftIO $ forkIO $ do
+    threadDelay (duration * 1000000)
+    writeChan chan (TimeoutMessage timeoutId)
+  res <- f
+  modify $ \s -> s { curTimeoutId = timeoutId + 1,
+                     overridingTimeout = False
+                   }
+  return res
+
+data LogMsgType = LogServer | LogClient
+  deriving Eq
+
+-- | Logs the message if the config specified it
+logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
+       => LogMsgType -> a -> m ()
+logMsg t msg = do
+  shouldLog <- asks $ logMessages . config
+  shouldColor <- asks $ logColor . config
+  liftIO $ when shouldLog $ do
+    when shouldColor $ setSGR [SetColor Foreground Dull color]
+    putStrLn $ arrow ++ showPretty msg
+    when shouldColor $ setSGR [Reset]
+
+  where arrow
+          | t == LogServer  = "<-- "
+          | otherwise       = "--> "
+        color
+          | t == LogServer  = Magenta
+          | otherwise       = Cyan
+
+        showPretty = B.unpack . encodePretty