Track changes to haskell-lsp
[lsp-test.git] / src / Language / Haskell / LSP / Test / Session.hs
index 39a3ed29bd4b44d406d4cc288eca49fa3540faaf..3426bcce47c66e105ddb235b13d308a92dcafe45 100644 (file)
@@ -1,10 +1,13 @@
+{-# LANGUAGE CPP               #-}
 {-# LANGUAGE OverloadedStrings #-}
 {-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
 {-# LANGUAGE FlexibleInstances #-}
 {-# LANGUAGE MultiParamTypeClasses #-}
 {-# LANGUAGE FlexibleContexts #-}
 {-# LANGUAGE FlexibleInstances #-}
 {-# LANGUAGE MultiParamTypeClasses #-}
 {-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE RankNTypes #-}
 
 module Language.Haskell.LSP.Test.Session
 
 module Language.Haskell.LSP.Test.Session
-  ( Session
+  ( Session(..)
   , SessionConfig(..)
   , defaultConfig
   , SessionMessage(..)
   , SessionConfig(..)
   , defaultConfig
   , SessionMessage(..)
@@ -26,16 +29,20 @@ module Language.Haskell.LSP.Test.Session
 
 where
 
 
 where
 
+import Control.Applicative
 import Control.Concurrent hiding (yield)
 import Control.Exception
 import Control.Lens hiding (List)
 import Control.Monad
 import Control.Monad.IO.Class
 import Control.Monad.Except
 import Control.Concurrent hiding (yield)
 import Control.Exception
 import Control.Lens hiding (List)
 import Control.Monad
 import Control.Monad.IO.Class
 import Control.Monad.Except
+#if __GLASGOW_HASKELL__ == 806
+import Control.Monad.Fail
+#endif
 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
 import qualified Control.Monad.Trans.Reader as Reader (ask)
 import Control.Monad.Trans.State (StateT, runStateT)
 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
 import qualified Control.Monad.Trans.Reader as Reader (ask)
 import Control.Monad.Trans.State (StateT, runStateT)
-import qualified Control.Monad.Trans.State as State (get, put)
+import qualified Control.Monad.Trans.State as State
 import qualified Data.ByteString.Lazy.Char8 as B
 import Data.Aeson
 import Data.Aeson.Encode.Pretty
 import qualified Data.ByteString.Lazy.Char8 as B
 import Data.Aeson
 import Data.Aeson.Encode.Pretty
@@ -52,37 +59,50 @@ import Data.Maybe
 import Data.Function
 import Language.Haskell.LSP.Messages
 import Language.Haskell.LSP.Types.Capabilities
 import Data.Function
 import Language.Haskell.LSP.Messages
 import Language.Haskell.LSP.Types.Capabilities
-import Language.Haskell.LSP.Types hiding (error)
+import Language.Haskell.LSP.Types
+import Language.Haskell.LSP.Types.Lens hiding (error)
 import Language.Haskell.LSP.VFS
 import Language.Haskell.LSP.VFS
+import Language.Haskell.LSP.Test.Compat
 import Language.Haskell.LSP.Test.Decoding
 import Language.Haskell.LSP.Test.Exceptions
 import System.Console.ANSI
 import System.Directory
 import System.IO
 import Language.Haskell.LSP.Test.Decoding
 import Language.Haskell.LSP.Test.Exceptions
 import System.Console.ANSI
 import System.Directory
 import System.IO
+import System.Process (ProcessHandle())
+import System.Timeout
 
 -- | A session representing one instance of launching and connecting to a server.
 --
 
 -- | A session representing one instance of launching and connecting to a server.
 --
--- You can send and receive messages to the server within 'Session' via 'getMessage',
--- 'sendRequest' and 'sendNotification'.
---
--- @
--- runSession \"path\/to\/root\/dir\" $ do
---   docItem <- getDocItem "Desktop/simple.hs" "haskell"
---   sendNotification TextDocumentDidOpen (DidOpenTextDocumentParams docItem)
---   diagnostics <- getMessage :: Session PublishDiagnosticsNotification
--- @
-type Session = ParserStateReader FromServerMessage SessionState SessionContext IO
+-- You can send and receive messages to the server within 'Session' via
+-- 'Language.Haskell.LSP.Test.message',
+-- 'Language.Haskell.LSP.Test.sendRequest' and
+-- 'Language.Haskell.LSP.Test.sendNotification'.
+
+newtype Session a = Session (ConduitParser FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) a)
+  deriving (Functor, Applicative, Monad, MonadIO, Alternative)
+
+#if __GLASGOW_HASKELL__ >= 806
+instance MonadFail Session where
+  fail s = do
+    lastMsg <- fromJust . lastReceivedMessage <$> get
+    liftIO $ throw (UnexpectedMessage s lastMsg)
+#endif
 
 -- | Stuff you can configure for a 'Session'.
 data SessionConfig = SessionConfig
   { messageTimeout :: Int  -- ^ Maximum time to wait for a message in seconds, defaults to 60.
   , logStdErr      :: Bool -- ^ Redirect the server's stderr to this stdout, defaults to False.
 
 -- | Stuff you can configure for a 'Session'.
 data SessionConfig = SessionConfig
   { messageTimeout :: Int  -- ^ Maximum time to wait for a message in seconds, defaults to 60.
   , logStdErr      :: Bool -- ^ Redirect the server's stderr to this stdout, defaults to False.
-  , logMessages    :: Bool -- ^ Trace the messages sent and received to stdout, defaults to True.
+  , logMessages    :: Bool -- ^ Trace the messages sent and received to stdout, defaults to False.
   , logColor       :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
   , logColor       :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
+  , lspConfig      :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
+  -- ^ Whether or not to ignore 'ShowMessageNotification' and 'LogMessageNotification', defaults to False.
+  -- @since 0.9.0.0
+  , ignoreLogNotifications :: Bool
   }
 
   }
 
+-- | The configuration used in 'Language.Haskell.LSP.Test.runSession'.
 defaultConfig :: SessionConfig
 defaultConfig :: SessionConfig
-defaultConfig = SessionConfig 60 False True True
+defaultConfig = SessionConfig 60 False False True Nothing False
 
 instance Default SessionConfig where
   def = defaultConfig
 
 instance Default SessionConfig where
   def = defaultConfig
@@ -107,17 +127,17 @@ class Monad m => HasReader r m where
   asks :: (r -> b) -> m b
   asks f = f <$> ask
 
   asks :: (r -> b) -> m b
   asks f = f <$> ask
 
-instance Monad m => HasReader r (ParserStateReader a s r m) where
-  ask = lift $ lift Reader.ask
+instance HasReader SessionContext Session where
+  ask  = Session (lift $ lift Reader.ask)
 
 
-instance Monad m => HasReader SessionContext (ConduitM a b (StateT s (ReaderT SessionContext m))) where
+instance Monad m => HasReader r (ConduitM a b (StateT s (ReaderT r m))) where
   ask = lift $ lift Reader.ask
 
 data SessionState = SessionState
   {
     curReqId :: LspId
   , vfs :: VFS
   ask = lift $ lift Reader.ask
 
 data SessionState = SessionState
   {
     curReqId :: LspId
   , vfs :: VFS
-  , curDiagnostics :: Map.Map Uri [Diagnostic]
+  , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
   , curTimeoutId :: Int
   , overridingTimeout :: Bool
   -- ^ The last received message from the server.
   , curTimeoutId :: Int
   , overridingTimeout :: Bool
   -- ^ The last received message from the server.
@@ -136,19 +156,22 @@ class Monad m => HasState s m where
   modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
   modifyM f = get >>= f >>= put
 
   modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
   modifyM f = get >>= f >>= put
 
-instance Monad m => HasState s (ParserStateReader a s r m) where
+instance HasState SessionState Session where
+  get = Session (lift State.get)
+  put = Session . lift . State.put
+
+instance Monad m => HasState s (ConduitM a b (StateT s m))
+ where
   get = lift State.get
   put = lift . State.put
 
   get = lift State.get
   put = lift . State.put
 
-instance Monad m => HasState SessionState (ConduitM a b (StateT SessionState m))
+instance Monad m => HasState s (ConduitParser a (StateT s m))
  where
   get = lift State.get
   put = lift . State.put
 
  where
   get = lift State.get
   put = lift . State.put
 
-type ParserStateReader a s r m = ConduitParser a (StateT s (ReaderT r m))
-
 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
-runSession context state session = runReaderT (runStateT conduit state) context
+runSession context state (Session session) = runReaderT (runStateT conduit state) context
   where
     conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
 
   where
     conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
 
@@ -161,44 +184,65 @@ runSession context state session = runReaderT (runStateT conduit state) context
 
     chanSource = do
       msg <- liftIO $ readChan (messageChan context)
 
     chanSource = do
       msg <- liftIO $ readChan (messageChan context)
+      unless (ignoreLogNotifications (config context) && isLogNotification msg) $
         yield msg
       chanSource
 
         yield msg
       chanSource
 
+    isLogNotification (ServerMessage (NotShowMessage _)) = True
+    isLogNotification (ServerMessage (NotLogMessage _)) = True
+    isLogNotification _ = False
+
     watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
     watchdog = Conduit.awaitForever $ \msg -> do
       curId <- curTimeoutId <$> get
       case msg of
         ServerMessage sMsg -> yield sMsg
     watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
     watchdog = Conduit.awaitForever $ \msg -> do
       curId <- curTimeoutId <$> get
       case msg of
         ServerMessage sMsg -> yield sMsg
-        TimeoutMessage tId -> when (curId == tId) $ throw Timeout
+        TimeoutMessage tId -> when (curId == tId) $ lastReceivedMessage <$> get >>= throw . Timeout
 
 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
 -- It also does not automatically send initialize and exit messages.
 runSessionWithHandles :: Handle -- ^ Server in
                       -> Handle -- ^ Server out
 
 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
 -- It also does not automatically send initialize and exit messages.
 runSessionWithHandles :: Handle -- ^ Server in
                       -> Handle -- ^ Server out
+                      -> ProcessHandle -- ^ Server process
                       -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
                       -> SessionConfig
                       -> ClientCapabilities
                       -> FilePath -- ^ Root directory
                       -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
                       -> SessionConfig
                       -> ClientCapabilities
                       -> FilePath -- ^ Root directory
+                      -> Session () -- ^ To exit the Server properly
                       -> Session a
                       -> IO a
                       -> Session a
                       -> IO a
-runSessionWithHandles serverIn serverOut serverHandler config caps rootDir session = do
+runSessionWithHandles serverIn serverOut serverProc serverHandler config caps rootDir exitServer session = do
   absRootDir <- canonicalizePath rootDir
 
   hSetBuffering serverIn  NoBuffering
   hSetBuffering serverOut NoBuffering
   absRootDir <- canonicalizePath rootDir
 
   hSetBuffering serverIn  NoBuffering
   hSetBuffering serverOut NoBuffering
+  -- This is required to make sure that we don’t get any
+  -- newline conversion or weird encoding issues.
+  hSetBinaryMode serverIn True
+  hSetBinaryMode serverOut True
 
   reqMap <- newMVar newRequestMap
   messageChan <- newChan
   initRsp <- newEmptyMVar
 
 
   reqMap <- newMVar newRequestMap
   messageChan <- newChan
   initRsp <- newEmptyMVar
 
-  let context = SessionContext serverIn absRootDir messageChan reqMap initRsp config caps
-      initState = SessionState (IdInt 0) mempty mempty 0 False Nothing
-
-  threadId <- forkIO $ void $ serverHandler serverOut context
-  (result, _) <- runSession context initState session
-
-  killThread threadId
+  mainThreadId <- myThreadId
 
 
+  let context = SessionContext serverIn absRootDir messageChan reqMap initRsp config caps
+      initState vfs = SessionState (IdInt 0) vfs
+                                       mempty 0 False Nothing
+      runSession' ses = initVFS $ \vfs -> runSession context (initState vfs) ses
+
+      errorHandler = throwTo mainThreadId :: SessionException -> IO()
+      serverListenerLauncher =
+        forkIO $ catch (serverHandler serverOut context) errorHandler
+      server = (Just serverIn, Just serverOut, Nothing, serverProc)
+      serverAndListenerFinalizer tid =
+        finally (timeout (messageTimeout config * 1000000)
+                         (runSession' exitServer))
+                (cleanupProcess server >> killThread tid)
+
+  (result, _) <- bracket serverListenerLauncher serverAndListenerFinalizer
+                         (const $ runSession' session)
   return result
 
 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
   return result
 
 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
@@ -206,12 +250,13 @@ updateStateC = awaitForever $ \msg -> do
   updateState msg
   yield msg
 
   updateState msg
   yield msg
 
-updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m) => FromServerMessage -> m ()
+updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m)
+            => FromServerMessage -> m ()
 updateState (NotPublishDiagnostics n) = do
   let List diags = n ^. params . diagnostics
       doc = n ^. params . uri
   modify (\s ->
 updateState (NotPublishDiagnostics n) = do
   let List diags = n ^. params . diagnostics
       doc = n ^. params . uri
   modify (\s ->
-    let newDiags = Map.insert doc diags (curDiagnostics s)
+    let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s)
       in s { curDiagnostics = newDiags })
 
 updateState (ReqApplyWorkspaceEdit r) = do
       in s { curDiagnostics = newDiags })
 
 updateState (ReqApplyWorkspaceEdit r) = do
@@ -230,7 +275,7 @@ updateState (ReqApplyWorkspaceEdit r) = do
     newVFS <- liftIO $ changeFromServerVFS (vfs s) r
     return $ s { vfs = newVFS }
 
     newVFS <- liftIO $ changeFromServerVFS (vfs s) r
     return $ s { vfs = newVFS }
 
-  let groupedParams = groupBy (\a b -> (a ^. textDocument == b ^. textDocument)) allChangeParams
+  let groupedParams = groupBy (\a b -> a ^. textDocument == b ^. textDocument) allChangeParams
       mergedParams = map mergeParams groupedParams
 
   -- TODO: Don't do this when replaying a session
       mergedParams = map mergeParams groupedParams
 
   -- TODO: Don't do this when replaying a session
@@ -244,8 +289,8 @@ updateState (ReqApplyWorkspaceEdit r) = do
   forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
     modify $ \s ->
       let oldVFS = vfs s
   forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
     modify $ \s ->
       let oldVFS = vfs s
-          update (VirtualFile oldV t) = VirtualFile (fromMaybe oldV v) t
-          newVFS = Map.adjust update uri oldVFS
+          update (VirtualFile oldV file_ver t) = VirtualFile (fromMaybe oldV v) (file_ver + 1) t
+          newVFS = updateVFS (Map.adjust update (toNormalizedUri uri)) oldVFS
       in s { vfs = newVFS }
 
   where checkIfNeedsOpened uri = do
       in s { vfs = newVFS }
 
   where checkIfNeedsOpened uri = do
@@ -253,7 +298,7 @@ updateState (ReqApplyWorkspaceEdit r) = do
           ctx <- ask
 
           -- if its not open, open it
           ctx <- ask
 
           -- if its not open, open it
-          unless (uri `Map.member` oldVFS) $ do
+          unless (toNormalizedUri uri `Map.member` vfsMap oldVFS) $ do
             let fp = fromJust $ uriToFilePath uri
             contents <- liftIO $ T.readFile fp
             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
             let fp = fromJust $ uriToFilePath uri
             contents <- liftIO $ T.readFile fp
             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
@@ -261,7 +306,7 @@ updateState (ReqApplyWorkspaceEdit r) = do
             liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
 
             modifyM $ \s -> do
             liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
 
             modifyM $ \s -> do
-              newVFS <- liftIO $ openVFS (vfs s) msg
+              let (newVFS,_) = openVFS (vfs s) msg
               return $ s { vfs = newVFS }
 
         getParams (TextDocumentEdit docId (List edits)) =
               return $ s { vfs = newVFS }
 
         getParams (TextDocumentEdit docId (List edits)) =
@@ -285,7 +330,7 @@ sendMessage msg = do
   logMsg LogClient msg
   liftIO $ B.hPut h (addHeader $ encode msg)
 
   logMsg LogClient msg
   liftIO $ B.hPut h (addHeader $ encode msg)
 
--- | Execute a block f that will throw a 'TimeoutException'
+-- | Execute a block f that will throw a 'Timeout' exception
 -- after duration seconds. This will override the global timeout
 -- for waiting for messages to arrive defined in 'SessionConfig'.
 withTimeout :: Int -> Session a -> Session a
 -- after duration seconds. This will override the global timeout
 -- for waiting for messages to arrive defined in 'SessionConfig'.
 withTimeout :: Int -> Session a -> Session a
@@ -302,14 +347,6 @@ withTimeout duration f = do
                    }
   return res
 
                    }
   return res
 
--- logClientMsg :: (MonadIO m, HasReader SessionContext m)
---              => FromClientMessage -> m ()
--- logClientMsg = logMsg True
-
--- logServerMsg :: (MonadIO m, HasReader SessionContext m)
---              => FromServerMessage -> m ()
--- logServerMsg = logMsg False
-
 data LogMsgType = LogServer | LogClient
   deriving Eq
 
 data LogMsgType = LogServer | LogClient
   deriving Eq
 
@@ -331,6 +368,6 @@ logMsg t msg = do
           | t == LogServer  = Magenta
           | otherwise       = Cyan
 
           | t == LogServer  = Magenta
           | otherwise       = Cyan
 
-
-showPretty :: ToJSON a => a -> String
         showPretty = B.unpack . encodePretty
         showPretty = B.unpack . encodePretty
+
+