module Vulkan.Utils.ShaderQQ
  ( glsl
  , comp
  , frag
  , geom
  , tesc
  , tese
  , vert
  , GLSLError
  , GLSLWarning
  , compileShaderQ
  , compileShader
  , processValidatorMessages
  ) where

import           Control.Monad.IO.Class
import           Data.ByteString                ( ByteString )
import qualified Data.ByteString               as BS
import qualified Data.ByteString.Lazy.Char8    as BSL
import           Data.Char
import           Data.FileEmbed
import           Data.List.Extra
import           Language.Haskell.TH
import           Language.Haskell.TH.Quote
import           System.Exit
import           System.FilePath
import           System.IO.Temp
import           System.Process.Typed
import           Vulkan.Utils.Internal          ( badQQ )
import           Vulkan.Utils.ShaderQQ.Interpolate

-- $setup
-- >>> :set -XQuasiQuotes

-- | 'glsl' is a QuasiQuoter which produces GLSL source code with @#line@
-- directives inserted so that error locations point to the correct location in
-- the Haskell source file. It also permits basic string interpolation.
--
-- - Interpolated variables are prefixed with @$@
-- - They can optionally be surrounded with braces like @${foo}@
-- - Interpolated variables are converted to strings with 'show'
-- - To escape a @$@ use @\\$@
--
-- It is intended to be used in concert with 'compileShaderQ' like so
--
-- @
-- myConstant = 3.141 -- Note that this will have to be in a different module
-- myFragmentShader = $(compileShaderQ "frag" [glsl|
--   #version 450
--   const float myConstant = ${myConstant};
--   main (){
--   }
-- |])
-- @
--
-- An explicit example (@<interactive>@ is from doctest):
--
-- >>> let version = 450 :: Int in [glsl|#version $version|]
-- "#version 450\n#extension GL_GOOGLE_cpp_style_line_directive : enable\n#line 46 \"<interactive>\"\n"
--
-- Note that the reported line numbers will be thrown off if any of the
-- interpolated variables contain newlines.
glsl :: QuasiQuoter
glsl = (badQQ "glsl")
  { quoteExp = \s -> do
      loc <- location
      -- Insert the #line directive here. 'compileShaderQ' will insert
      -- another one, but that one ends up before this one, so this one
      -- still determines the reported locations.
      let codeWithLineDirective = insertLineDirective s loc
      -- Only expression quotes are supported; the other QuasiQuoter fields
      -- keep whatever 'badQQ' provides.
      interpExp codeWithLineDirective
  }

-- | QuasiQuoter for creating a compute shader.
--
-- Equivalent to calling @$(compileShaderQ "comp" [glsl|...|])@ without
-- interpolation support.
comp :: QuasiQuoter
comp = shaderQQ "comp"

-- | QuasiQuoter for creating a fragment shader.
--
-- Equivalent to calling @$(compileShaderQ "frag" [glsl|...|])@ without
-- interpolation support.
frag :: QuasiQuoter
frag = shaderQQ "frag"

-- | QuasiQuoter for creating a geometry shader.
--
-- Equivalent to calling @$(compileShaderQ "geom" [glsl|...|])@ without
-- interpolation support.
geom :: QuasiQuoter
geom = shaderQQ "geom"

-- | QuasiQuoter for creating a tessellation control shader.
--
-- Equivalent to calling @$(compileShaderQ "tesc" [glsl|...|])@ without
-- interpolation support.
tesc :: QuasiQuoter
tesc = shaderQQ "tesc"

-- | QuasiQuoter for creating a tessellation evaluation shader.
--
-- Equivalent to calling @$(compileShaderQ "tese" [glsl|...|])@ without
-- interpolation support.
tese :: QuasiQuoter
tese = shaderQQ "tese"

-- | QuasiQuoter for creating a vertex shader.
--
-- Equivalent to calling @$(compileShaderQ "vert" [glsl|...|])@ without
-- interpolation support.
vert :: QuasiQuoter
vert = shaderQQ "vert"

-- | Build the QuasiQuoter for one shader stage.
--
-- The quoted text is passed to 'compileShaderQ' with no @--target-env@
-- argument and with the given stage name. Only 'quoteExp' is overridden; the
-- remaining fields are those of @badQQ stage@.
shaderQQ :: String -> QuasiQuoter
shaderQQ stage = (badQQ stage) { quoteExp = compileShaderQ Nothing stage }

-- * Utilities

-- | Compile a glsl shader to spir-v using glslangValidator.
--
-- Messages are converted to GHC warnings or errors depending on compilation success.
compileShaderQ
  :: Maybe String
  -- ^ Argument to pass to @--target-env@
  -> String
  -- ^ stage
  -> String
  -- ^ glsl code
  -> Q Exp
  -- ^ Spir-V bytecode
compileShaderQ targetEnv stage code = do
  loc                <- location
  (warnings, result) <- compileShader (Just loc) targetEnv stage code

  -- Warnings are reported whether or not compilation succeeded.
  case warnings of
    []    -> pure ()
    _some -> reportWarning $ prepare warnings

  bs <- case result of
    -- The validator failed but no "ERROR: " line was recognised, so there is
    -- nothing more specific to report.
    Left []     -> fail "glslangValidator failed with no errors"
    -- Report the errors and carry on with empty bytecode instead of
    -- aborting the splice here.
    Left errors -> do
      reportError $ prepare errors
      pure mempty
    Right spirv -> pure spirv

  bsToExp bs

 where
  -- A single message is reported as-is; several messages are indented under
  -- a "glslangValidator:" header line.
  prepare :: [String] -> String
  prepare [singleLine] = singleLine
  prepare multiline =
    intercalate "\n" $ "glslangValidator:" : map (mappend "        ") multiline

-- | One error message line extracted from the output of glslangValidator.
type GLSLError = String
-- | One warning message line extracted from the output of glslangValidator.
type GLSLWarning = String

-- | Compile a glsl shader to spir-v using glslangValidator
compileShader
  :: MonadIO m
  => Maybe Loc
  -- ^ Source location
  -> Maybe String
  -- ^ Argument to pass to @--target-env@
  -> String
  -- ^ stage
  -> String
  -- ^ glsl code
  -> m ([GLSLWarning], Either [GLSLError] ByteString)
  -- ^ Spir-V bytecode with warnings or errors
compileShader loc targetEnv stage code =
  liftIO $ withSystemTempDirectory "th-shader" $ \dir -> do
    -- Only add a #line directive when the caller supplied a location.
    let codeWithLineDirective = maybe code (insertLineDirective code) loc
        -- The source file's extension is the stage name.
        shader = dir </> ("shader." <> stage)
        spirv  = dir </> "shader.spv"
    writeFile shader codeWithLineDirective

    let targetArgs = case targetEnv of
          Nothing -> []
          Just t  -> ["--target-env", t]
        args = targetArgs ++ ["-S", stage, "-V", shader, "-o", spirv]
    (rc, out, err) <- readProcess $ proc "glslangValidator" args
    -- Messages are collected from stdout and stderr together.
    let (warnings, errors) = processValidatorMessages (out <> err)
    case rc of
      ExitSuccess -> do
        bs <- BS.readFile spirv
        pure (warnings, Right bs)
      -- On failure the output file is not read; the parsed errors are
      -- returned alongside any warnings.
      ExitFailure _rc -> pure (warnings, Left errors)

-- | Split the output of glslangValidator into warnings and errors.
--
-- Empty lines and lines that start with neither @WARNING: @ nor @ERROR: @ are
-- dropped. The relative order of the kept lines is preserved. For each kept
-- line the leading word (up to and including the first space) is removed,
-- and the text before the first @:@ of the remainder is reduced to its file
-- name component with 'takeFileName'.
processValidatorMessages :: BSL.ByteString -> ([GLSLWarning], [GLSLError])
processValidatorMessages =
  foldr grep ([], []) . filter (not . null) . lines . BSL.unpack
 where
  -- Sort one line into the warning or error list according to its prefix.
  grep line (ws, es)
    | "WARNING: " `isPrefixOf` line = (cut line : ws, es)
    | "ERROR: " `isPrefixOf` line   = (ws, cut line : es)
    | otherwise                     = (ws, es)

  -- Strip the severity tag, then shorten whatever precedes the first ':' to
  -- a bare file name; the ':' and the rest of the message are kept as-is.
  cut line = takeFileName path <> msg
    where (path, msg) = break (== ':') . drop 1 $ dropWhile (/= ' ') line

-- If possible, insert a #line directive after the #version directive (as well
-- as the extension which allows filenames in line directives).
--
-- When no line of the code starts (after leading whitespace) with "#version",
-- the code is returned unchanged. Otherwise the result is rebuilt with
-- 'unlines', so it always ends with a newline.
insertLineDirective :: String -> Loc -> String
insertLineDirective code Loc { loc_filename = file, loc_start = start } =
  let isVersionDirective = ("#version" `isPrefixOf`) . dropWhile isSpace
      codeLines          = lines code
      (beforeVersion, afterVersion) = break isVersionDirective codeLines
      -- The directive names the line that follows the #version line: the
      -- quote's starting line, plus the lines preceding #version, plus one.
      lineDirective =
        [ "#extension GL_GOOGLE_cpp_style_line_directive : enable"
        , "#line "
          <> show (fst start + length beforeVersion + 1)
          <> " \""
          <> file
          <> "\""
        ]
  in  case afterVersion of
        []     -> code
        v : xs -> unlines $ beforeVersion <> [v] <> lineDirective <> xs