Validate fragments are input types

This commit is contained in:
Eugen Wissner 2020-09-20 06:59:27 +02:00
parent 21a7d9cce4
commit 38c3097bcf
6 changed files with 80 additions and 36 deletions

View File

@ -42,6 +42,7 @@ and this project adheres to
- `uniqueArgumentNamesRule` - `uniqueArgumentNamesRule`
- `uniqueDirectiveNamesRule` - `uniqueDirectiveNamesRule`
- `uniqueVariableNamesRule` - `uniqueVariableNamesRule`
- `variablesAreInputTypesRule`
- `AST.Document.Field`. - `AST.Document.Field`.
- `AST.Document.FragmentSpread`. - `AST.Document.FragmentSpread`.
- `AST.Document.InlineFragment`. - `AST.Document.InlineFragment`.

View File

@ -47,7 +47,6 @@ import Language.GraphQL.AST (Name)
import qualified Language.GraphQL.Execute.Coerce as Coerce import qualified Language.GraphQL.Execute.Coerce as Coerce
import qualified Language.GraphQL.Type.Definition as Definition import qualified Language.GraphQL.Type.Definition as Definition
import qualified Language.GraphQL.Type as Type import qualified Language.GraphQL.Type as Type
import qualified Language.GraphQL.Type.In as In
import Language.GraphQL.Type.Internal import Language.GraphQL.Type.Internal
import qualified Language.GraphQL.Type.Out as Out import qualified Language.GraphQL.Type.Out as Out
import Language.GraphQL.Type.Schema import Language.GraphQL.Type.Schema
@ -139,35 +138,6 @@ getOperation (Just operationName) operations
matchingName (OperationDefinition _ name _ _ _) = matchingName (OperationDefinition _ name _ _ _) =
name == Just operationName name == Just operationName
lookupInputType
:: Full.Type
-> HashMap.HashMap Full.Name (Type m)
-> Maybe In.Type
lookupInputType (Full.TypeNamed name) types =
case HashMap.lookup name types of
Just (ScalarType scalarType) ->
Just $ In.NamedScalarType scalarType
Just (EnumType enumType) ->
Just $ In.NamedEnumType enumType
Just (InputObjectType objectType) ->
Just $ In.NamedInputObjectType objectType
_ -> Nothing
lookupInputType (Full.TypeList list) types
= In.ListType
<$> lookupInputType list types
lookupInputType (Full.TypeNonNull (Full.NonNullTypeNamed nonNull)) types =
case HashMap.lookup nonNull types of
Just (ScalarType scalarType) ->
Just $ In.NonNullScalarType scalarType
Just (EnumType enumType) ->
Just $ In.NonNullEnumType enumType
Just (InputObjectType objectType) ->
Just $ In.NonNullInputObjectType objectType
_ -> Nothing
lookupInputType (Full.TypeNonNull (Full.NonNullTypeList nonNull)) types
= In.NonNullListType
<$> lookupInputType nonNull types
coerceVariableValues :: Coerce.VariableValue a coerceVariableValues :: Coerce.VariableValue a
=> forall m => forall m
. HashMap Full.Name (Type m) . HashMap Full.Name (Type m)

View File

@ -10,12 +10,13 @@ module Language.GraphQL.Type.Internal
, collectReferencedTypes , collectReferencedTypes
, doesFragmentTypeApply , doesFragmentTypeApply
, instanceOf , instanceOf
, lookupInputType
, lookupTypeCondition , lookupTypeCondition
) where ) where
import Data.HashMap.Strict (HashMap) import Data.HashMap.Strict (HashMap)
import qualified Data.HashMap.Strict as HashMap import qualified Data.HashMap.Strict as HashMap
import Language.GraphQL.AST (Name) import qualified Language.GraphQL.AST as Full
import qualified Language.GraphQL.Type.Definition as Definition import qualified Language.GraphQL.Type.Definition as Definition
import qualified Language.GraphQL.Type.In as In import qualified Language.GraphQL.Type.In as In
import qualified Language.GraphQL.Type.Out as Out import qualified Language.GraphQL.Type.Out as Out
@ -35,7 +36,7 @@ data AbstractType m
deriving Eq deriving Eq
-- | Traverses the schema and finds all referenced types. -- | Traverses the schema and finds all referenced types.
collectReferencedTypes :: forall m. Schema m -> HashMap Name (Type m) collectReferencedTypes :: forall m. Schema m -> HashMap Full.Name (Type m)
collectReferencedTypes schema = collectReferencedTypes schema =
let queryTypes = traverseObjectType (query schema) HashMap.empty let queryTypes = traverseObjectType (query schema) HashMap.empty
mutationTypes = maybe queryTypes (`traverseObjectType` queryTypes) mutationTypes = maybe queryTypes (`traverseObjectType` queryTypes)
@ -121,8 +122,8 @@ instanceOf objectType (AbstractUnionType unionType) =
go unionMemberType acc = acc || objectType == unionMemberType go unionMemberType acc = acc || objectType == unionMemberType
lookupTypeCondition :: forall m lookupTypeCondition :: forall m
. Name . Full.Name
-> HashMap Name (Type m) -> HashMap Full.Name (Type m)
-> Maybe (CompositeType m) -> Maybe (CompositeType m)
lookupTypeCondition type' types' = lookupTypeCondition type' types' =
case HashMap.lookup type' types' of case HashMap.lookup type' types' of
@ -131,3 +132,32 @@ lookupTypeCondition type' types' =
Just (InterfaceType interfaceType) -> Just (InterfaceType interfaceType) ->
Just $ CompositeInterfaceType interfaceType Just $ CompositeInterfaceType interfaceType
_ -> Nothing _ -> Nothing
lookupInputType
:: Full.Type
-> HashMap.HashMap Full.Name (Type m)
-> Maybe In.Type
lookupInputType (Full.TypeNamed name) types =
case HashMap.lookup name types of
Just (ScalarType scalarType) ->
Just $ In.NamedScalarType scalarType
Just (EnumType enumType) ->
Just $ In.NamedEnumType enumType
Just (InputObjectType objectType) ->
Just $ In.NamedInputObjectType objectType
_ -> Nothing
lookupInputType (Full.TypeList list) types
= In.ListType
<$> lookupInputType list types
lookupInputType (Full.TypeNonNull (Full.NonNullTypeNamed nonNull)) types =
case HashMap.lookup nonNull types of
Just (ScalarType scalarType) ->
Just $ In.NonNullScalarType scalarType
Just (EnumType enumType) ->
Just $ In.NonNullEnumType enumType
Just (InputObjectType objectType) ->
Just $ In.NonNullInputObjectType objectType
_ -> Nothing
lookupInputType (Full.TypeNonNull (Full.NonNullTypeList nonNull)) types
= In.NonNullListType
<$> lookupInputType nonNull types

View File

@ -23,9 +23,10 @@ module Language.GraphQL.Validate.Rules
, uniqueFragmentNamesRule , uniqueFragmentNamesRule
, uniqueOperationNamesRule , uniqueOperationNamesRule
, uniqueVariableNamesRule , uniqueVariableNamesRule
, variablesAreInputTypesRule
) where ) where
import Control.Monad (foldM) import Control.Monad ((>=>), foldM)
import Control.Monad.Trans.Class (MonadTrans(..)) import Control.Monad.Trans.Class (MonadTrans(..))
import Control.Monad.Trans.Reader (ReaderT, asks) import Control.Monad.Trans.Reader (ReaderT, asks)
import Control.Monad.Trans.State (StateT, evalStateT, gets, modify) import Control.Monad.Trans.State (StateT, evalStateT, gets, modify)
@ -67,6 +68,7 @@ specifiedRules =
, uniqueDirectiveNamesRule , uniqueDirectiveNamesRule
-- Variables. -- Variables.
, uniqueVariableNamesRule , uniqueVariableNamesRule
, variablesAreInputTypesRule
] ]
-- | Definition must be OperationDefinition or FragmentDefinition. -- | Definition must be OperationDefinition or FragmentDefinition.
@ -505,3 +507,29 @@ uniqueVariableNamesRule = VariablesRule
where where
extract (VariableDefinition variableName _ _ location) = extract (VariableDefinition variableName _ _ location) =
(variableName, location) (variableName, location)
-- | Variables can only be input types. Objects, unions and interfaces cannot be
-- used as inputs.
variablesAreInputTypesRule :: forall m. Rule m
variablesAreInputTypesRule = VariablesRule
$ (traverse check . Seq.fromList) >=> lift
where
check (VariableDefinition name typeName _ location)
= asks types
>>= lift
. maybe (makeError name typeName location) (const mempty)
. lookupInputType typeName
makeError name typeName location = pure $ Error
{ message = concat
[ "Variable \"$"
, Text.unpack name
, "\" cannot be non-input type \""
, Text.unpack $ getTypeName typeName
, "\"."
]
, locations = [location]
}
getTypeName (TypeNamed name) = name
getTypeName (TypeList name) = getTypeName name
getTypeName (TypeNonNull (NonNullTypeNamed nonNull)) = nonNull
getTypeName (TypeNonNull (NonNullTypeList nonNull)) = getTypeName nonNull

View File

@ -1,4 +1,4 @@
resolver: lts-16.14 resolver: lts-16.15
packages: packages:
- . - .

View File

@ -456,3 +456,18 @@ spec =
, locations = [AST.Location 2 39, AST.Location 2 63] , locations = [AST.Location 2 39, AST.Location 2 63]
} }
in validate queryString `shouldBe` Seq.singleton expected in validate queryString `shouldBe` Seq.singleton expected
it "rejects non-input types as variables" $
let queryString = [r|
query takesDogBang($dog: Dog!) {
dog {
isHousetrained(atOtherHomes: $dog)
}
}
|]
expected = Error
{ message =
"Variable \"$dog\" cannot be non-input type \"Dog\"."
, locations = [AST.Location 2 34]
}
in validate queryString `shouldBe` Seq.singleton expected