Validate required input fields

This commit is contained in:
Eugen Wissner 2020-10-04 18:51:21 +02:00
parent d5f518fe82
commit a91bc7f2d2
9 changed files with 128 additions and 48 deletions

View File

@ -54,6 +54,7 @@ and this project adheres to
- `knownDirectiveNamesRule` - `knownDirectiveNamesRule`
- `directivesInValidLocationsRule` - `directivesInValidLocationsRule`
- `providedRequiredArgumentsRule` - `providedRequiredArgumentsRule`
- `providedRequiredInputFieldsRule`
- `AST.Document.Field`. - `AST.Document.Field`.
- `AST.Document.FragmentSpread`. - `AST.Document.FragmentSpread`.
- `AST.Document.InlineFragment`. - `AST.Document.InlineFragment`.

View File

@ -1,5 +1,7 @@
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE ExplicitForAll #-} {-# LANGUAGE ExplicitForAll #-}
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE Safe #-} {-# LANGUAGE Safe #-}
-- | This module defines an abstract syntax tree for the @GraphQL@ language. It -- | This module defines an abstract syntax tree for the @GraphQL@ language. It
@ -75,10 +77,13 @@ instance Ord Location where
-- | Contains some tree node with a location. -- | Contains some tree node with a location.
data Node a = Node data Node a = Node
{ value :: a { node :: a
, location :: Location , location :: Location
} deriving (Eq, Show) } deriving (Eq, Show)
instance Functor Node where
fmap f Node{..} = Node (f node) location
-- ** Document -- ** Document
-- | GraphQL document. -- | GraphQL document.
@ -241,8 +246,11 @@ data ConstValue
-- | Key-value pair. -- | Key-value pair.
-- --
-- A list of 'ObjectField's represents a GraphQL object type. -- A list of 'ObjectField's represents a GraphQL object type.
data ObjectField a = ObjectField Name a Location data ObjectField a = ObjectField
deriving (Eq, Show) { name :: Name
, value :: Node a
, location :: Location
} deriving (Eq, Show)
-- ** Variables -- ** Variables

View File

@ -1,6 +1,7 @@
{-# LANGUAGE ExplicitForAll #-} {-# LANGUAGE ExplicitForAll #-}
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE Safe #-} {-# LANGUAGE Safe #-}
-- | This module defines a minifier and a printer for the @GraphQL@ language. -- | This module defines a minifier and a printer for the @GraphQL@ language.
@ -101,7 +102,7 @@ variableDefinition formatter variableDefinition' =
in variable variableName in variable variableName
<> eitherFormat formatter ": " ":" <> eitherFormat formatter ": " ":"
<> type' variableType <> type' variableType
<> maybe mempty (defaultValue formatter) (Full.value <$> defaultValue') <> maybe mempty (defaultValue formatter) (Full.node <$> defaultValue')
defaultValue :: Formatter -> Full.ConstValue -> Lazy.Text defaultValue :: Formatter -> Full.ConstValue -> Lazy.Text
defaultValue formatter val defaultValue formatter val
@ -164,7 +165,7 @@ argument :: Formatter -> Full.Argument -> Lazy.Text
argument formatter (Full.Argument name value' _) argument formatter (Full.Argument name value' _)
= Lazy.Text.fromStrict name = Lazy.Text.fromStrict name
<> colon formatter <> colon formatter
<> value formatter (Full.value value') <> value formatter (Full.node value')
-- * Fragments -- * Fragments
@ -222,8 +223,8 @@ fromConstValue (Full.ConstEnum x) = Full.Enum x
fromConstValue (Full.ConstList x) = Full.List $ fromConstValue <$> x fromConstValue (Full.ConstList x) = Full.List $ fromConstValue <$> x
fromConstValue (Full.ConstObject x) = Full.Object $ fromConstObjectField <$> x fromConstValue (Full.ConstObject x) = Full.Object $ fromConstObjectField <$> x
where where
fromConstObjectField (Full.ObjectField key value' location) = fromConstObjectField Full.ObjectField{value = value', ..} =
Full.ObjectField key (fromConstValue value') location Full.ObjectField name (fromConstValue <$> value') location
booleanValue :: Bool -> Lazy.Text booleanValue :: Bool -> Lazy.Text
booleanValue True = "true" booleanValue True = "true"
@ -292,7 +293,7 @@ objectValue formatter = intercalate $ objectField formatter
. fmap f . fmap f
objectField :: Formatter -> Full.ObjectField Full.Value -> Lazy.Text objectField :: Formatter -> Full.ObjectField Full.Value -> Lazy.Text
objectField formatter (Full.ObjectField name value' _) = objectField formatter (Full.ObjectField name (Full.Node value' _) _) =
Lazy.Text.fromStrict name <> colon formatter <> value formatter value' Lazy.Text.fromStrict name <> colon formatter <> value formatter value'
-- | Converts a 'Type' a type into a string. -- | Converts a 'Type' a type into a string.

View File

@ -461,7 +461,7 @@ value = Full.Variable <$> variable
<|> Full.String <$> stringValue <|> Full.String <$> stringValue
<|> Full.Enum <$> try enumValue <|> Full.Enum <$> try enumValue
<|> Full.List <$> brackets (some value) <|> Full.List <$> brackets (some value)
<|> Full.Object <$> braces (some $ objectField value) <|> Full.Object <$> braces (some $ objectField $ valueNode value)
<?> "Value" <?> "Value"
constValue :: Parser Full.ConstValue constValue :: Parser Full.ConstValue
@ -472,7 +472,7 @@ constValue = Full.ConstFloat <$> try float
<|> Full.ConstString <$> stringValue <|> Full.ConstString <$> stringValue
<|> Full.ConstEnum <$> try enumValue <|> Full.ConstEnum <$> try enumValue
<|> Full.ConstList <$> brackets (some constValue) <|> Full.ConstList <$> brackets (some constValue)
<|> Full.ConstObject <$> braces (some $ objectField constValue) <|> Full.ConstObject <$> braces (some $ objectField $ valueNode constValue)
<?> "Value" <?> "Value"
booleanValue :: Parser Bool booleanValue :: Parser Bool
@ -493,7 +493,7 @@ stringValue = blockString <|> string <?> "StringValue"
nullValue :: Parser Text nullValue :: Parser Text
nullValue = symbol "null" <?> "NullValue" nullValue = symbol "null" <?> "NullValue"
objectField :: Parser a -> Parser (Full.ObjectField a) objectField :: forall a. Parser (Full.Node a) -> Parser (Full.ObjectField a)
objectField valueParser = label "ObjectField" $ do objectField valueParser = label "ObjectField" $ do
location <- getLocation location <- getLocation
fieldName <- name fieldName <- name

View File

@ -153,7 +153,7 @@ coerceVariableValues types operationDefinition variableValues =
forEach variableDefinition coercedValues = do forEach variableDefinition coercedValues = do
let Full.VariableDefinition variableName variableTypeName defaultValue _ = let Full.VariableDefinition variableName variableTypeName defaultValue _ =
variableDefinition variableDefinition
let defaultValue' = constValue . Full.value <$> defaultValue let defaultValue' = constValue . Full.node <$> defaultValue
variableType <- lookupInputType variableTypeName types variableType <- lookupInputType variableTypeName types
Coerce.matchFieldValues Coerce.matchFieldValues
@ -178,7 +178,8 @@ constValue (Full.ConstList l) = Type.List $ constValue <$> l
constValue (Full.ConstObject o) = constValue (Full.ConstObject o) =
Type.Object $ HashMap.fromList $ constObjectField <$> o Type.Object $ HashMap.fromList $ constObjectField <$> o
where where
constObjectField (Full.ObjectField key value' _) = (key, constValue value') constObjectField Full.ObjectField{value = value', ..} =
(name, constValue $ Full.node value')
-- | Rewrites the original syntax tree into an intermediate representation used -- | Rewrites the original syntax tree into an intermediate representation used
-- for query execution. -- for query execution.
@ -384,7 +385,8 @@ value (Full.List list) = Type.List <$> traverse value list
value (Full.Object object) = value (Full.Object object) =
Type.Object . HashMap.fromList <$> traverse objectField object Type.Object . HashMap.fromList <$> traverse objectField object
where where
objectField (Full.ObjectField name value' _) = (name,) <$> value value' objectField Full.ObjectField{value = value', ..} =
(name,) <$> value (Full.node value')
input :: forall m. Full.Value -> State (Replacement m) (Maybe Input) input :: forall m. Full.Value -> State (Replacement m) (Maybe Input)
input (Full.Variable name) = input (Full.Variable name) =
@ -400,8 +402,8 @@ input (Full.Object object) = do
objectFields <- foldM objectField HashMap.empty object objectFields <- foldM objectField HashMap.empty object
pure $ pure $ Object objectFields pure $ pure $ Object objectFields
where where
objectField resultMap (Full.ObjectField name value' _) = objectField resultMap Full.ObjectField{value = value', ..} =
inputField resultMap name value' inputField resultMap name $ Full.node value'
inputField :: forall m inputField :: forall m
. HashMap Full.Name Input . HashMap Full.Name Input

View File

@ -3,6 +3,7 @@
obtain one at https://mozilla.org/MPL/2.0/. -} obtain one at https://mozilla.org/MPL/2.0/. -}
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE ScopedTypeVariables #-}
@ -349,25 +350,25 @@ variableDefinition :: forall m
variableDefinition context rule (Full.VariableDefinition _ typeName value' _) variableDefinition context rule (Full.VariableDefinition _ typeName value' _)
| Just defaultValue' <- value' | Just defaultValue' <- value'
, variableType <- lookupInputType typeName $ Validation.types context = , variableType <- lookupInputType typeName $ Validation.types context =
constValue rule variableType $ Full.value defaultValue' constValue rule variableType defaultValue'
variableDefinition _ _ _ = mempty variableDefinition _ _ _ = mempty
constValue :: forall m constValue :: forall m
. Validation.Rule m . Validation.Rule m
-> Maybe In.Type -> Maybe In.Type
-> Full.ConstValue -> Full.Node Full.ConstValue
-> Seq (Validation.RuleT m) -> Seq (Validation.RuleT m)
constValue (Validation.ValueRule _ rule) valueType = go valueType constValue (Validation.ValueRule _ rule) valueType = go valueType
where where
go inputObjectType value'@(Full.ConstObject fields) go inputObjectType value'@(Full.Node (Full.ConstObject fields) _)
= foldMap (forEach inputObjectType) (Seq.fromList fields) = foldMap (forEach inputObjectType) (Seq.fromList fields)
|> rule inputObjectType value' |> rule inputObjectType value'
go listType value'@(Full.ConstList values) go listType value'@(Full.Node (Full.ConstList values) location')
= foldMap (go $ valueTypeFromList listType) (Seq.fromList values) = embedListLocation go listType values location'
|> rule listType value' |> rule listType value'
go anotherValue value' = pure $ rule anotherValue value' go anotherValue value' = pure $ rule anotherValue value'
forEach inputObjectType (Full.ObjectField fieldName fieldValue _) = forEach inputObjectType Full.ObjectField{value = value', ..} =
go (valueTypeByName fieldName inputObjectType) fieldValue go (valueTypeByName name inputObjectType) value'
constValue _ _ = const mempty constValue _ _ = const mempty
inputFieldType :: In.InputField -> In.Type inputFieldType :: In.InputField -> In.Type
@ -379,10 +380,6 @@ valueTypeByName fieldName (Just( In.InputObjectBaseType inputObjectType)) =
in inputFieldType <$> HashMap.lookup fieldName fieldTypes in inputFieldType <$> HashMap.lookup fieldName fieldTypes
valueTypeByName _ _ = Nothing valueTypeByName _ _ = Nothing
valueTypeFromList :: Maybe In.Type -> Maybe In.Type
valueTypeFromList (Just (In.ListBaseType listType)) = Just listType
valueTypeFromList _ = Nothing
fragmentDefinition :: forall m fragmentDefinition :: forall m
. Validation.Rule m . Validation.Rule m
-> Validation m -> Validation m
@ -468,26 +465,40 @@ argument :: forall m
-> Full.Argument -> Full.Argument
-> Seq (Validation.RuleT m) -> Seq (Validation.RuleT m)
argument rule argumentType (Full.Argument _ value' _) = argument rule argumentType (Full.Argument _ value' _) =
value rule (valueType <$> argumentType) $ Full.value value' value rule (valueType <$> argumentType) value'
where where
valueType (In.Argument _ valueType' _) = valueType' valueType (In.Argument _ valueType' _) = valueType'
-- valueTypeFromList :: Maybe In.Type -> Maybe In.Type
embedListLocation :: forall a m
. (Maybe In.Type -> Full.Node a -> Seq m)
-> Maybe In.Type
-> [a]
-> Full.Location
-> Seq m
embedListLocation go listType values location'
= foldMap (go $ valueTypeFromList listType)
$ flip Full.Node location' <$> Seq.fromList values
where
valueTypeFromList (Just (In.ListBaseType baseType)) = Just baseType
valueTypeFromList _ = Nothing
value :: forall m value :: forall m
. Validation.Rule m . Validation.Rule m
-> Maybe In.Type -> Maybe In.Type
-> Full.Value -> Full.Node Full.Value
-> Seq (Validation.RuleT m) -> Seq (Validation.RuleT m)
value (Validation.ValueRule rule _) valueType = go valueType value (Validation.ValueRule rule _) valueType = go valueType
where where
go inputObjectType value'@(Full.Object fields) go inputObjectType value'@(Full.Node (Full.Object fields) _)
= foldMap (forEach inputObjectType) (Seq.fromList fields) = foldMap (forEach inputObjectType) (Seq.fromList fields)
|> rule inputObjectType value' |> rule inputObjectType value'
go listType value'@(Full.List values) go listType value'@(Full.Node (Full.List values) location')
= foldMap (go $ valueTypeFromList listType) (Seq.fromList values) = embedListLocation go listType values location'
|> rule listType value' |> rule listType value'
go anotherValue value' = pure $ rule anotherValue value' go anotherValue value' = pure $ rule anotherValue value'
forEach inputObjectType (Full.ObjectField fieldName fieldValue _) = forEach inputObjectType Full.ObjectField{value = value', ..} =
go (valueTypeByName fieldName inputObjectType) fieldValue go (valueTypeByName name inputObjectType) value'
value _ _ = const mempty value _ _ = const mempty
inlineFragment :: forall m inlineFragment :: forall m

View File

@ -24,6 +24,7 @@ module Language.GraphQL.Validate.Rules
, noUndefinedVariablesRule , noUndefinedVariablesRule
, noUnusedFragmentsRule , noUnusedFragmentsRule
, noUnusedVariablesRule , noUnusedVariablesRule
, providedRequiredInputFieldsRule
, providedRequiredArgumentsRule , providedRequiredArgumentsRule
, scalarLeafsRule , scalarLeafsRule
, singleFieldSubscriptionsRule , singleFieldSubscriptionsRule
@ -91,6 +92,7 @@ specifiedRules =
-- Values -- Values
, knownInputFieldNamesRule , knownInputFieldNamesRule
, uniqueInputFieldNamesRule , uniqueInputFieldNamesRule
, providedRequiredInputFieldsRule
-- Directives. -- Directives.
, knownDirectiveNamesRule , knownDirectiveNamesRule
, directivesInValidLocationsRule , directivesInValidLocationsRule
@ -646,7 +648,7 @@ variableUsageDifference difference errorMessage = OperationDefinitionRule $ \cas
findDirectiveVariables (Directive _ arguments _) = mapArguments arguments findDirectiveVariables (Directive _ arguments _) = mapArguments arguments
mapArguments = Seq.fromList . mapMaybe findArgumentVariables mapArguments = Seq.fromList . mapMaybe findArgumentVariables
mapDirectives = foldMap findDirectiveVariables mapDirectives = foldMap findDirectiveVariables
findArgumentVariables (Argument _ Node{ value = Variable value', ..} _) = findArgumentVariables (Argument _ Node{ node = Variable value', ..} _) =
Just (value', [location]) Just (value', [location])
findArgumentVariables _ = Nothing findArgumentVariables _ = Nothing
makeError operationName (variableName, locations') = Error makeError operationName (variableName, locations') = Error
@ -680,12 +682,12 @@ uniqueInputFieldNamesRule :: forall m. Rule m
uniqueInputFieldNamesRule = uniqueInputFieldNamesRule =
ValueRule (const $ lift . go) (const $ lift . constGo) ValueRule (const $ lift . go) (const $ lift . constGo)
where where
go (Object fields) = filterFieldDuplicates fields go (Node (Object fields) _) = filterFieldDuplicates fields
go _ = mempty go _ = mempty
filterFieldDuplicates fields = filterFieldDuplicates fields =
filterDuplicates getFieldName "input field" fields filterDuplicates getFieldName "input field" fields
getFieldName (ObjectField fieldName _ location') = (fieldName, location') getFieldName (ObjectField fieldName _ location') = (fieldName, location')
constGo (ConstObject fields) = filterFieldDuplicates fields constGo (Node (ConstObject fields) _) = filterFieldDuplicates fields
constGo _ = mempty constGo _ = mempty
-- | The target field of a field selection must be defined on the scoped type of -- | The target field of a field selection must be defined on the scoped type of
@ -848,11 +850,11 @@ knownDirectiveNamesRule = DirectivesRule $ const $ \directives' -> do
knownInputFieldNamesRule :: Rule m knownInputFieldNamesRule :: Rule m
knownInputFieldNamesRule = ValueRule go constGo knownInputFieldNamesRule = ValueRule go constGo
where where
go (Just valueType) (Object inputFields) go (Just valueType) (Node (Object inputFields) _)
| In.InputObjectBaseType objectType <- valueType = | In.InputObjectBaseType objectType <- valueType =
lift $ Seq.fromList $ mapMaybe (forEach objectType) inputFields lift $ Seq.fromList $ mapMaybe (forEach objectType) inputFields
go _ _ = lift mempty go _ _ = lift mempty
constGo (Just valueType) (ConstObject inputFields) constGo (Just valueType) (Node (ConstObject inputFields) _)
| In.InputObjectBaseType objectType <- valueType = | In.InputObjectBaseType objectType <- valueType =
lift $ Seq.fromList $ mapMaybe (forEach objectType) inputFields lift $ Seq.fromList $ mapMaybe (forEach objectType) inputFields
constGo _ _ = lift mempty constGo _ _ = lift mempty
@ -915,13 +917,6 @@ providedRequiredArgumentsRule = ArgumentsRule fieldRule directiveRule
let forEach = go (directiveMessage directiveName) arguments location' let forEach = go (directiveMessage directiveName) arguments location'
in lift $ HashMap.foldrWithKey forEach Seq.empty definitions in lift $ HashMap.foldrWithKey forEach Seq.empty definitions
_ -> lift mempty _ -> lift mempty
inputTypeName (In.ScalarBaseType (Definition.ScalarType typeName _)) =
typeName
inputTypeName (In.EnumBaseType (Definition.EnumType typeName _ _)) =
typeName
inputTypeName (In.InputObjectBaseType (In.InputObjectType typeName _ _)) =
typeName
inputTypeName (In.ListBaseType listType) = inputTypeName listType
go makeMessage arguments location' argumentName argumentType errors go makeMessage arguments location' argumentName argumentType errors
| In.Argument _ type' optionalValue <- argumentType | In.Argument _ type' optionalValue <- argumentType
, In.isNonNullType type' , In.isNonNullType type'
@ -956,3 +951,49 @@ providedRequiredArgumentsRule = ArgumentsRule fieldRule directiveRule
, Text.unpack typeName , Text.unpack typeName
, "\" is required, but it was not provided." , "\" is required, but it was not provided."
] ]
inputTypeName :: In.Type -> Text
inputTypeName (In.ScalarBaseType (Definition.ScalarType typeName _)) = typeName
inputTypeName (In.EnumBaseType (Definition.EnumType typeName _ _)) = typeName
inputTypeName (In.InputObjectBaseType (In.InputObjectType typeName _ _)) =
typeName
inputTypeName (In.ListBaseType listType) = inputTypeName listType
-- | Input object fields may be required. Much like a field may have required
-- arguments, an input object may have required fields. An input field is
-- required if it has a nonnull type and does not have a default value.
-- Otherwise, the input object field is optional.
providedRequiredInputFieldsRule :: Rule m
providedRequiredInputFieldsRule = ValueRule go constGo
where
go (Just valueType) (Node (Object inputFields) location')
| In.InputObjectBaseType objectType <- valueType
, In.InputObjectType objectTypeName _ fieldDefinitions <- objectType
= lift
$ Seq.fromList
$ HashMap.elems
$ flip HashMap.mapMaybeWithKey fieldDefinitions
$ forEach inputFields objectTypeName location'
go _ _ = lift mempty
constGo _ _ = lift mempty
forEach inputFields typeName location' definitionName fieldDefinition
| In.InputField _ inputType optionalValue <- fieldDefinition
, In.isNonNullType inputType
, isNothing optionalValue
, isNothingOrNull $ find (lookupField definitionName) inputFields =
Just $ makeError definitionName typeName location'
| otherwise = Nothing
isNothingOrNull (Just (ObjectField _ (Node Null _) _)) = True
isNothingOrNull x = isNothing x
lookupField needle (ObjectField fieldName _ _) = needle == fieldName
makeError fieldName typeName location' = Error
{ message = errorMessage fieldName typeName
, locations = [location']
}
errorMessage fieldName typeName = concat
[ "Input field \""
, Text.unpack fieldName
, "\" of type \""
, Text.unpack typeName
, "\" is required, but it was not provided."
]

View File

@ -48,7 +48,7 @@ data Rule m
| ArgumentsRule (Maybe (Out.Type m) -> Field -> RuleT m) (Directive -> RuleT m) | ArgumentsRule (Maybe (Out.Type m) -> Field -> RuleT m) (Directive -> RuleT m)
| DirectivesRule (DirectiveLocation -> [Directive] -> RuleT m) | DirectivesRule (DirectiveLocation -> [Directive] -> RuleT m)
| VariablesRule ([VariableDefinition] -> RuleT m) | VariablesRule ([VariableDefinition] -> RuleT m)
| ValueRule (Maybe In.Type -> Value -> RuleT m) (Maybe In.Type -> ConstValue -> RuleT m) | ValueRule (Maybe In.Type -> Node Value -> RuleT m) (Maybe In.Type -> Node ConstValue -> RuleT m)
-- | Monad transformer used by the rules. -- | Monad transformer used by the rules.
type RuleT m = ReaderT (Validation m) Seq Error type RuleT m = ReaderT (Validation m) Seq Error

View File

@ -594,7 +594,7 @@ spec =
it "rejects undefined input object fields" $ it "rejects undefined input object fields" $
let queryString = [r| let queryString = [r|
{ {
findDog(complex: { favoriteCookieFlavor: "Bacon" }) { findDog(complex: { favoriteCookieFlavor: "Bacon", name: "Jack" }) {
name name
} }
} }
@ -620,3 +620,19 @@ spec =
, locations = [AST.Location 2 21] , locations = [AST.Location 2 21]
} }
in validate queryString `shouldBe` [expected] in validate queryString `shouldBe` [expected]
it "rejects missing required input fields" $
let queryString = [r|
{
findDog(complex: { name: null }) {
name
}
}
|]
expected = Error
{ message =
"Input field \"name\" of type \"DogData\" is required, \
\but it was not provided."
, locations = [AST.Location 3 34]
}
in validate queryString `shouldBe` [expected]