Track changes to haskell-lsp
[lsp-test.git] / src / Language / Haskell / LSP / Test / Session.hs
index 291a3e0bcaa990327988f2bcbd9df6554a63f50f..3426bcce47c66e105ddb235b13d308a92dcafe45 100644 (file)
@@ -1,12 +1,13 @@
 {-# LANGUAGE CPP               #-}
 {-# LANGUAGE OverloadedStrings #-}
 {-# LANGUAGE CPP               #-}
 {-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
 {-# LANGUAGE FlexibleInstances #-}
 {-# LANGUAGE MultiParamTypeClasses #-}
 {-# LANGUAGE FlexibleContexts #-}
 {-# LANGUAGE RankNTypes #-}
 
 module Language.Haskell.LSP.Test.Session
 {-# LANGUAGE FlexibleInstances #-}
 {-# LANGUAGE MultiParamTypeClasses #-}
 {-# LANGUAGE FlexibleContexts #-}
 {-# LANGUAGE RankNTypes #-}
 
 module Language.Haskell.LSP.Test.Session
-  ( Session
+  ( Session(..)
   , SessionConfig(..)
   , defaultConfig
   , SessionMessage(..)
   , SessionConfig(..)
   , defaultConfig
   , SessionMessage(..)
@@ -28,19 +29,20 @@ module Language.Haskell.LSP.Test.Session
 
 where
 
 
 where
 
+import Control.Applicative
 import Control.Concurrent hiding (yield)
 import Control.Exception
 import Control.Lens hiding (List)
 import Control.Monad
 import Control.Monad.IO.Class
 import Control.Monad.Except
 import Control.Concurrent hiding (yield)
 import Control.Exception
 import Control.Lens hiding (List)
 import Control.Monad
 import Control.Monad.IO.Class
 import Control.Monad.Except
-#if __GLASGOW_HASKELL__ >= 806
+#if __GLASGOW_HASKELL__ == 806
 import Control.Monad.Fail
 #endif
 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
 import qualified Control.Monad.Trans.Reader as Reader (ask)
 import Control.Monad.Trans.State (StateT, runStateT)
 import Control.Monad.Fail
 #endif
 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
 import qualified Control.Monad.Trans.Reader as Reader (ask)
 import Control.Monad.Trans.State (StateT, runStateT)
-import qualified Control.Monad.Trans.State as State (get, put)
+import qualified Control.Monad.Trans.State as State
 import qualified Data.ByteString.Lazy.Char8 as B
 import Data.Aeson
 import Data.Aeson.Encode.Pretty
 import qualified Data.ByteString.Lazy.Char8 as B
 import Data.Aeson
 import Data.Aeson.Encode.Pretty
@@ -60,11 +62,14 @@ import Language.Haskell.LSP.Types.Capabilities
 import Language.Haskell.LSP.Types
 import Language.Haskell.LSP.Types.Lens hiding (error)
 import Language.Haskell.LSP.VFS
 import Language.Haskell.LSP.Types
 import Language.Haskell.LSP.Types.Lens hiding (error)
 import Language.Haskell.LSP.VFS
+import Language.Haskell.LSP.Test.Compat
 import Language.Haskell.LSP.Test.Decoding
 import Language.Haskell.LSP.Test.Exceptions
 import System.Console.ANSI
 import System.Directory
 import System.IO
 import Language.Haskell.LSP.Test.Decoding
 import Language.Haskell.LSP.Test.Exceptions
 import System.Console.ANSI
 import System.Directory
 import System.IO
+import System.Process (ProcessHandle())
+import System.Timeout
 
 -- | A session representing one instance of launching and connecting to a server.
 --
 
 -- | A session representing one instance of launching and connecting to a server.
 --
@@ -73,7 +78,8 @@ import System.IO
 -- 'Language.Haskell.LSP.Test.sendRequest' and
 -- 'Language.Haskell.LSP.Test.sendNotification'.
 
 -- 'Language.Haskell.LSP.Test.sendRequest' and
 -- 'Language.Haskell.LSP.Test.sendNotification'.
 
-type Session = ParserStateReader FromServerMessage SessionState SessionContext IO
+newtype Session a = Session (ConduitParser FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) a)
+  deriving (Functor, Applicative, Monad, MonadIO, Alternative)
 
 #if __GLASGOW_HASKELL__ >= 806
 instance MonadFail Session where
 
 #if __GLASGOW_HASKELL__ >= 806
 instance MonadFail Session where
@@ -89,11 +95,14 @@ data SessionConfig = SessionConfig
   , logMessages    :: Bool -- ^ Trace the messages sent and received to stdout, defaults to False.
   , logColor       :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
   , lspConfig      :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
   , logMessages    :: Bool -- ^ Trace the messages sent and received to stdout, defaults to False.
   , logColor       :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
   , lspConfig      :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
+  -- ^ Whether or not to ignore 'ShowMessageNotification' and 'LogMessageNotification', defaults to False.
+  -- @since 0.9.0.0
+  , ignoreLogNotifications :: Bool
   }
 
 -- | The configuration used in 'Language.Haskell.LSP.Test.runSession'.
 defaultConfig :: SessionConfig
   }
 
 -- | The configuration used in 'Language.Haskell.LSP.Test.runSession'.
 defaultConfig :: SessionConfig
-defaultConfig = SessionConfig 60 False False True Nothing
+defaultConfig = SessionConfig 60 False False True Nothing False
 
 instance Default SessionConfig where
   def = defaultConfig
 
 instance Default SessionConfig where
   def = defaultConfig
@@ -118,17 +127,17 @@ class Monad m => HasReader r m where
   asks :: (r -> b) -> m b
   asks f = f <$> ask
 
   asks :: (r -> b) -> m b
   asks f = f <$> ask
 
-instance Monad m => HasReader r (ParserStateReader a s r m) where
-  ask = lift $ lift Reader.ask
+instance HasReader SessionContext Session where
+  ask  = Session (lift $ lift Reader.ask)
 
 
-instance Monad m => HasReader SessionContext (ConduitM a b (StateT s (ReaderT SessionContext m))) where
+instance Monad m => HasReader r (ConduitM a b (StateT s (ReaderT r m))) where
   ask = lift $ lift Reader.ask
 
 data SessionState = SessionState
   {
     curReqId :: LspId
   , vfs :: VFS
   ask = lift $ lift Reader.ask
 
 data SessionState = SessionState
   {
     curReqId :: LspId
   , vfs :: VFS
-  , curDiagnostics :: Map.Map Uri [Diagnostic]
+  , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
   , curTimeoutId :: Int
   , overridingTimeout :: Bool
   -- ^ The last received message from the server.
   , curTimeoutId :: Int
   , overridingTimeout :: Bool
   -- ^ The last received message from the server.
@@ -147,19 +156,22 @@ class Monad m => HasState s m where
   modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
   modifyM f = get >>= f >>= put
 
   modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
   modifyM f = get >>= f >>= put
 
-instance Monad m => HasState s (ParserStateReader a s r m) where
+instance HasState SessionState Session where
+  get = Session (lift State.get)
+  put = Session . lift . State.put
+
+instance Monad m => HasState s (ConduitM a b (StateT s m))
+ where
   get = lift State.get
   put = lift . State.put
 
   get = lift State.get
   put = lift . State.put
 
-instance Monad m => HasState SessionState (ConduitM a b (StateT SessionState m))
+instance Monad m => HasState s (ConduitParser a (StateT s m))
  where
   get = lift State.get
   put = lift . State.put
 
  where
   get = lift State.get
   put = lift . State.put
 
-type ParserStateReader a s r m = ConduitParser a (StateT s (ReaderT r m))
-
 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
-runSession context state session = runReaderT (runStateT conduit state) context
+runSession context state (Session session) = runReaderT (runStateT conduit state) context
   where
     conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
 
   where
     conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
 
@@ -172,31 +184,42 @@ runSession context state session = runReaderT (runStateT conduit state) context
 
     chanSource = do
       msg <- liftIO $ readChan (messageChan context)
 
     chanSource = do
       msg <- liftIO $ readChan (messageChan context)
+      unless (ignoreLogNotifications (config context) && isLogNotification msg) $
         yield msg
       chanSource
 
         yield msg
       chanSource
 
+    isLogNotification (ServerMessage (NotShowMessage _)) = True
+    isLogNotification (ServerMessage (NotLogMessage _)) = True
+    isLogNotification _ = False
+
     watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
     watchdog = Conduit.awaitForever $ \msg -> do
       curId <- curTimeoutId <$> get
       case msg of
         ServerMessage sMsg -> yield sMsg
     watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
     watchdog = Conduit.awaitForever $ \msg -> do
       curId <- curTimeoutId <$> get
       case msg of
         ServerMessage sMsg -> yield sMsg
-        TimeoutMessage tId -> when (curId == tId) $ throw Timeout
+        TimeoutMessage tId -> when (curId == tId) $ lastReceivedMessage <$> get >>= throw . Timeout
 
 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
 -- It also does not automatically send initialize and exit messages.
 runSessionWithHandles :: Handle -- ^ Server in
                       -> Handle -- ^ Server out
 
 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
 -- It also does not automatically send initialize and exit messages.
 runSessionWithHandles :: Handle -- ^ Server in
                       -> Handle -- ^ Server out
+                      -> ProcessHandle -- ^ Server process
                       -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
                       -> SessionConfig
                       -> ClientCapabilities
                       -> FilePath -- ^ Root directory
                       -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
                       -> SessionConfig
                       -> ClientCapabilities
                       -> FilePath -- ^ Root directory
+                      -> Session () -- ^ To exit the Server properly
                       -> Session a
                       -> IO a
                       -> Session a
                       -> IO a
-runSessionWithHandles serverIn serverOut serverHandler config caps rootDir session = do
+runSessionWithHandles serverIn serverOut serverProc serverHandler config caps rootDir exitServer session = do
   absRootDir <- canonicalizePath rootDir
 
   hSetBuffering serverIn  NoBuffering
   hSetBuffering serverOut NoBuffering
   absRootDir <- canonicalizePath rootDir
 
   hSetBuffering serverIn  NoBuffering
   hSetBuffering serverOut NoBuffering
+  -- This is required to make sure that we don’t get any
+  -- newline conversion or weird encoding issues.
+  hSetBinaryMode serverIn True
+  hSetBinaryMode serverOut True
 
   reqMap <- newMVar newRequestMap
   messageChan <- newChan
 
   reqMap <- newMVar newRequestMap
   messageChan <- newChan
@@ -205,12 +228,21 @@ runSessionWithHandles serverIn serverOut serverHandler config caps rootDir sessi
   mainThreadId <- myThreadId
 
   let context = SessionContext serverIn absRootDir messageChan reqMap initRsp config caps
   mainThreadId <- myThreadId
 
   let context = SessionContext serverIn absRootDir messageChan reqMap initRsp config caps
-      initState = SessionState (IdInt 0) mempty mempty 0 False Nothing
-      launchServerHandler = forkIO $ catch (serverHandler serverOut context)
-                                           (throwTo mainThreadId :: SessionException -> IO ())
-  (result, _) <- bracket launchServerHandler killThread $
-    const $ runSession context initState session
-
+      initState vfs = SessionState (IdInt 0) vfs
+                                       mempty 0 False Nothing
+      runSession' ses = initVFS $ \vfs -> runSession context (initState vfs) ses
+
+      errorHandler = throwTo mainThreadId :: SessionException -> IO()
+      serverListenerLauncher =
+        forkIO $ catch (serverHandler serverOut context) errorHandler
+      server = (Just serverIn, Just serverOut, Nothing, serverProc)
+      serverAndListenerFinalizer tid =
+        finally (timeout (messageTimeout config * 1000000)
+                         (runSession' exitServer))
+                (cleanupProcess server >> killThread tid)
+
+  (result, _) <- bracket serverListenerLauncher serverAndListenerFinalizer
+                         (const $ runSession' session)
   return result
 
 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
   return result
 
 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
@@ -218,12 +250,13 @@ updateStateC = awaitForever $ \msg -> do
   updateState msg
   yield msg
 
   updateState msg
   yield msg
 
-updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m) => FromServerMessage -> m ()
+updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m)
+            => FromServerMessage -> m ()
 updateState (NotPublishDiagnostics n) = do
   let List diags = n ^. params . diagnostics
       doc = n ^. params . uri
   modify (\s ->
 updateState (NotPublishDiagnostics n) = do
   let List diags = n ^. params . diagnostics
       doc = n ^. params . uri
   modify (\s ->
-    let newDiags = Map.insert doc diags (curDiagnostics s)
+    let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s)
       in s { curDiagnostics = newDiags })
 
 updateState (ReqApplyWorkspaceEdit r) = do
       in s { curDiagnostics = newDiags })
 
 updateState (ReqApplyWorkspaceEdit r) = do
@@ -242,7 +275,7 @@ updateState (ReqApplyWorkspaceEdit r) = do
     newVFS <- liftIO $ changeFromServerVFS (vfs s) r
     return $ s { vfs = newVFS }
 
     newVFS <- liftIO $ changeFromServerVFS (vfs s) r
     return $ s { vfs = newVFS }
 
-  let groupedParams = groupBy (\a b -> (a ^. textDocument == b ^. textDocument)) allChangeParams
+  let groupedParams = groupBy (\a b -> a ^. textDocument == b ^. textDocument) allChangeParams
       mergedParams = map mergeParams groupedParams
 
   -- TODO: Don't do this when replaying a session
       mergedParams = map mergeParams groupedParams
 
   -- TODO: Don't do this when replaying a session
@@ -256,8 +289,8 @@ updateState (ReqApplyWorkspaceEdit r) = do
   forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
     modify $ \s ->
       let oldVFS = vfs s
   forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
     modify $ \s ->
       let oldVFS = vfs s
-          update (VirtualFile oldV t) = VirtualFile (fromMaybe oldV v) t
-          newVFS = Map.adjust update uri oldVFS
+          update (VirtualFile oldV file_ver t) = VirtualFile (fromMaybe oldV v) (file_ver + 1) t
+          newVFS = updateVFS (Map.adjust update (toNormalizedUri uri)) oldVFS
       in s { vfs = newVFS }
 
   where checkIfNeedsOpened uri = do
       in s { vfs = newVFS }
 
   where checkIfNeedsOpened uri = do
@@ -265,7 +298,7 @@ updateState (ReqApplyWorkspaceEdit r) = do
           ctx <- ask
 
           -- if its not open, open it
           ctx <- ask
 
           -- if its not open, open it
-          unless (uri `Map.member` oldVFS) $ do
+          unless (toNormalizedUri uri `Map.member` vfsMap oldVFS) $ do
             let fp = fromJust $ uriToFilePath uri
             contents <- liftIO $ T.readFile fp
             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
             let fp = fromJust $ uriToFilePath uri
             contents <- liftIO $ T.readFile fp
             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
@@ -273,7 +306,7 @@ updateState (ReqApplyWorkspaceEdit r) = do
             liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
 
             modifyM $ \s -> do
             liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
 
             modifyM $ \s -> do
-              newVFS <- liftIO $ openVFS (vfs s) msg
+              let (newVFS,_) = openVFS (vfs s) msg
               return $ s { vfs = newVFS }
 
         getParams (TextDocumentEdit docId (List edits)) =
               return $ s { vfs = newVFS }
 
         getParams (TextDocumentEdit docId (List edits)) =
@@ -337,3 +370,4 @@ logMsg t msg = do
 
         showPretty = B.unpack . encodePretty
 
 
         showPretty = B.unpack . encodePretty
 
+