Add ReplayOutOfOrder exception and change function signature
[opengl.git] / src / Language / Haskell / LSP / Test / Replay.hs
index dfa364b44c3600b4682eca2f94699f2fea16fad3..ad26858ee39632d4a5e8260a83576ef0a65a93b6 100644 (file)
@@ -9,13 +9,15 @@ import           Prelude hiding (id)
 import           Control.Concurrent
 import           Control.Monad.IO.Class
 import qualified Data.ByteString.Lazy.Char8    as B
+import qualified Data.Text                     as T
 import           Language.Haskell.LSP.Capture
 import           Language.Haskell.LSP.Messages
-import           Language.Haskell.LSP.Types hiding (error)
+import           Language.Haskell.LSP.Types as LSP hiding (error)
 import           Data.Aeson
+import           Data.Default
 import           Data.List
 import           Data.Maybe
-import           Control.Lens
+import           Control.Lens hiding (List)
 import           Control.Monad
 import           System.IO
 import           System.FilePath
@@ -23,15 +25,16 @@ import           Language.Haskell.LSP.Test
 import           Language.Haskell.LSP.Test.Files
 import           Language.Haskell.LSP.Test.Decoding
 import           Language.Haskell.LSP.Test.Messages
+import           Language.Haskell.LSP.Test.Server
 
 
 -- | Replays a captured client output and 
 -- makes sure it matches up with an expected response.
 -- The session directory should have a captured session file in it
 -- named "session.log".
-replaySession :: FilePath -- ^ The filepath to the server executable.
+replaySession :: String -- ^ The command to run the server.
               -> FilePath -- ^ The recorded session directory.
-              -> IO Bool
+              -> IO ()
 replaySession serverExe sessionDir = do
 
   entries <- B.lines <$> B.readFile (sessionDir </> "session.log")
@@ -39,7 +42,9 @@ replaySession serverExe sessionDir = do
   -- decode session
   let unswappedEvents = map (fromJust . decode) entries
 
-  events <- swapFiles sessionDir unswappedEvents
+  withServer serverExe $ \serverIn serverOut pid -> do
+
+    events <- swapCommands pid <$> swapFiles sessionDir unswappedEvents
 
     let clientEvents = filter isClientMsg events
         serverEvents = filter isServerMsg events
@@ -49,17 +54,18 @@ replaySession serverExe sessionDir = do
 
     reqSema <- newEmptyMVar
     rspSema <- newEmptyMVar
-  passVar <- newEmptyMVar :: IO (MVar Bool)
-
-  threadId <- forkIO $
-    runSessionWithHandler (listenServer serverMsgs requestMap reqSema rspSema passVar)
-                          serverExe
+    passSema <- newEmptyMVar
+    mainThread <- myThreadId
+
+    sessionThread <- liftIO $ forkIO $
+      runSessionWithHandles serverIn
+                            serverOut
+                            (listenServer serverMsgs requestMap reqSema rspSema passSema mainThread)
+                            def
                             sessionDir
                             (sendMessages clientMsgs reqSema rspSema)
-
-  result <- takeMVar passVar
-  killThread threadId
-  return result
+    takeMVar passSema
+    killThread sessionThread
 
   where
     isClientMsg (FromClient _ _) = True
@@ -117,26 +123,23 @@ isNotification (NotShowMessage             _) = True
 isNotification (NotCancelRequestFromServer _) = True
 isNotification _                              = False
 
-listenServer :: [FromServerMessage] -> RequestMap -> MVar LspId -> MVar LspIdRsp -> MVar Bool -> Handle -> Session ()
-listenServer [] _ _ _ passVar _ = liftIO $ putMVar passVar True
-listenServer expectedMsgs reqMap reqSema rspSema passVar serverOut  = do
+listenServer :: [FromServerMessage] -> RequestMap -> MVar LspId -> MVar LspIdRsp -> MVar () -> ThreadId -> Handle -> Session ()
+listenServer [] _ _ _ passSema _ _ = liftIO $ putMVar passSema ()
+listenServer expectedMsgs reqMap reqSema rspSema passSema mainThreadId serverOut  = do
+
   msgBytes <- liftIO $ getNextMessage serverOut
   let msg = decodeFromServerMsg reqMap msgBytes
 
   handleServerMessage request response notification msg
 
   if shouldSkip msg
-    then listenServer expectedMsgs reqMap reqSema rspSema passVar serverOut
+    then listenServer expectedMsgs reqMap reqSema rspSema passSema mainThreadId serverOut
     else if inRightOrder msg expectedMsgs
-      then listenServer (delete msg expectedMsgs) reqMap reqSema rspSema passVar serverOut
-      else liftIO $ do
-        putStrLn "Out of order"
-        putStrLn "Got:"
-        print msg
-        putStrLn "Expected one of:"
-        mapM_ print $ takeWhile (not . isNotification) expectedMsgs
-        print $ head $ dropWhile isNotification expectedMsgs
-        putMVar passVar False
+      then listenServer (delete msg expectedMsgs) reqMap reqSema rspSema passSema mainThreadId serverOut
+      else let expectedMsgs = takeWhile (not . isNotification) expectedMsgs
+                ++ [head $ dropWhile isNotification expectedMsgs]
+               exc = ReplayOutOfOrderException msg expectedMsgs
+            in liftIO $ throwTo mainThreadId exc
 
   where
   response :: ResponseMessage a -> Session ()
@@ -186,3 +189,27 @@ shouldSkip (NotLogMessage  _) = True
 shouldSkip (NotShowMessage _) = True
 shouldSkip (ReqShowMessage _) = True
 shouldSkip _                  = False
+
+-- | Swaps out any commands uniqued with process IDs to match the specified process ID
+swapCommands :: Int -> [Event] -> [Event]
+swapCommands _ [] = []
+
+swapCommands pid (FromClient t (ReqExecuteCommand req):xs) =  FromClient t (ReqExecuteCommand swapped):swapCommands pid xs
+  where swapped = params . command .~ newCmd $ req
+        newCmd = swapPid pid (req ^. params . command)
+
+swapCommands pid (FromServer t (RspInitialize rsp):xs) = FromServer t (RspInitialize swapped):swapCommands pid xs
+  where swapped = case newCommands of
+          Just cmds -> result . _Just . LSP.capabilities . executeCommandProvider . _Just . commands .~ cmds $ rsp
+          Nothing -> rsp
+        oldCommands = rsp ^? result . _Just . LSP.capabilities . executeCommandProvider . _Just . commands
+        newCommands = fmap (fmap (swapPid pid)) oldCommands
+
+swapCommands pid (x:xs) = x:swapCommands pid xs
+
+hasPid :: T.Text -> Bool
+hasPid = (>= 2) . T.length . T.filter (':' ==)
+swapPid :: Int -> T.Text -> T.Text
+swapPid pid t
+  | hasPid t = T.append (T.pack $ show pid) $ T.dropWhile (/= ':') t
+  | otherwise = t
\ No newline at end of file