From a91bc7f2d218ea2df308d3968587b60351625150 Mon Sep 17 00:00:00 2001 From: Eugen Wissner Date: Sun, 4 Oct 2020 18:51:21 +0200 Subject: [PATCH] Validate required input fields --- CHANGELOG.md | 1 + src/Language/GraphQL/AST/Document.hs | 14 ++++- src/Language/GraphQL/AST/Encoder.hs | 11 ++-- src/Language/GraphQL/AST/Parser.hs | 6 +- src/Language/GraphQL/Execute/Transform.hs | 12 ++-- src/Language/GraphQL/Validate.hs | 47 +++++++++------ src/Language/GraphQL/Validate/Rules.hs | 65 +++++++++++++++++---- src/Language/GraphQL/Validate/Validation.hs | 2 +- tests/Language/GraphQL/ValidateSpec.hs | 18 +++++- 9 files changed, 128 insertions(+), 48 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ec1eecf..a7e9c41 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -54,6 +54,7 @@ and this project adheres to - `knownDirectiveNamesRule` - `directivesInValidLocationsRule` - `providedRequiredArgumentsRule` + - `providedRequiredInputFieldsRule` - `AST.Document.Field`. - `AST.Document.FragmentSpread`. - `AST.Document.InlineFragment`. diff --git a/src/Language/GraphQL/AST/Document.hs b/src/Language/GraphQL/AST/Document.hs index 489a242..b30271c 100644 --- a/src/Language/GraphQL/AST/Document.hs +++ b/src/Language/GraphQL/AST/Document.hs @@ -1,5 +1,7 @@ +{-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE ExplicitForAll #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE RecordWildCards #-} {-# LANGUAGE Safe #-} -- | This module defines an abstract syntax tree for the @GraphQL@ language. It @@ -75,10 +77,13 @@ instance Ord Location where -- | Contains some tree node with a location. data Node a = Node - { value :: a + { node :: a , location :: Location } deriving (Eq, Show) +instance Functor Node where + fmap f Node{..} = Node (f node) location + -- ** Document -- | GraphQL document. @@ -241,8 +246,11 @@ data ConstValue -- | Key-value pair. -- -- A list of 'ObjectField's represents a GraphQL object type. -data ObjectField a = ObjectField Name a Location - deriving (Eq, Show) +data ObjectField a = ObjectField + { name :: Name + , value :: Node a + , location :: Location + } deriving (Eq, Show) -- ** Variables diff --git a/src/Language/GraphQL/AST/Encoder.hs b/src/Language/GraphQL/AST/Encoder.hs index 51a801e..9b3eea3 100644 --- a/src/Language/GraphQL/AST/Encoder.hs +++ b/src/Language/GraphQL/AST/Encoder.hs @@ -1,6 +1,7 @@ {-# LANGUAGE ExplicitForAll #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE RecordWildCards #-} {-# LANGUAGE Safe #-} -- | This module defines a minifier and a printer for the @GraphQL@ language. @@ -101,7 +102,7 @@ variableDefinition formatter variableDefinition' = in variable variableName <> eitherFormat formatter ": " ":" <> type' variableType - <> maybe mempty (defaultValue formatter) (Full.value <$> defaultValue') + <> maybe mempty (defaultValue formatter) (Full.node <$> defaultValue') defaultValue :: Formatter -> Full.ConstValue -> Lazy.Text defaultValue formatter val @@ -164,7 +165,7 @@ argument :: Formatter -> Full.Argument -> Lazy.Text argument formatter (Full.Argument name value' _) = Lazy.Text.fromStrict name <> colon formatter - <> value formatter (Full.value value') + <> value formatter (Full.node value') -- * Fragments @@ -222,8 +223,8 @@ fromConstValue (Full.ConstEnum x) = Full.Enum x fromConstValue (Full.ConstList x) = Full.List $ fromConstValue <$> x fromConstValue (Full.ConstObject x) = Full.Object $ fromConstObjectField <$> x where - fromConstObjectField (Full.ObjectField key value' location) = - Full.ObjectField key (fromConstValue value') location + fromConstObjectField Full.ObjectField{value = value', ..} = + Full.ObjectField name (fromConstValue <$> value') location booleanValue :: Bool -> Lazy.Text booleanValue True = "true" @@ -292,7 +293,7 @@ objectValue formatter = intercalate $ objectField formatter . fmap f objectField :: Formatter -> Full.ObjectField Full.Value -> Lazy.Text -objectField formatter (Full.ObjectField name value' _) = +objectField formatter (Full.ObjectField name (Full.Node value' _) _) = Lazy.Text.fromStrict name <> colon formatter <> value formatter value' -- | Converts a 'Type' a type into a string. diff --git a/src/Language/GraphQL/AST/Parser.hs b/src/Language/GraphQL/AST/Parser.hs index 2695e6f..46c8fa3 100644 --- a/src/Language/GraphQL/AST/Parser.hs +++ b/src/Language/GraphQL/AST/Parser.hs @@ -461,7 +461,7 @@ value = Full.Variable <$> variable <|> Full.String <$> stringValue <|> Full.Enum <$> try enumValue <|> Full.List <$> brackets (some value) - <|> Full.Object <$> braces (some $ objectField value) + <|> Full.Object <$> braces (some $ objectField $ valueNode value) "Value" constValue :: Parser Full.ConstValue @@ -472,7 +472,7 @@ constValue = Full.ConstFloat <$> try float <|> Full.ConstString <$> stringValue <|> Full.ConstEnum <$> try enumValue <|> Full.ConstList <$> brackets (some constValue) - <|> Full.ConstObject <$> braces (some $ objectField constValue) + <|> Full.ConstObject <$> braces (some $ objectField $ valueNode constValue) "Value" booleanValue :: Parser Bool @@ -493,7 +493,7 @@ stringValue = blockString <|> string "StringValue" nullValue :: Parser Text nullValue = symbol "null" "NullValue" -objectField :: Parser a -> Parser (Full.ObjectField a) +objectField :: forall a. Parser (Full.Node a) -> Parser (Full.ObjectField a) objectField valueParser = label "ObjectField" $ do location <- getLocation fieldName <- name diff --git a/src/Language/GraphQL/Execute/Transform.hs b/src/Language/GraphQL/Execute/Transform.hs index 80e7a83..a8a2ae2 100644 --- a/src/Language/GraphQL/Execute/Transform.hs +++ b/src/Language/GraphQL/Execute/Transform.hs @@ -153,7 +153,7 @@ coerceVariableValues types operationDefinition variableValues = forEach variableDefinition coercedValues = do let Full.VariableDefinition variableName variableTypeName defaultValue _ = variableDefinition - let defaultValue' = constValue . Full.value <$> defaultValue + let defaultValue' = constValue . Full.node <$> defaultValue variableType <- lookupInputType variableTypeName types Coerce.matchFieldValues @@ -178,7 +178,8 @@ constValue (Full.ConstList l) = Type.List $ constValue <$> l constValue (Full.ConstObject o) = Type.Object $ HashMap.fromList $ constObjectField <$> o where - constObjectField (Full.ObjectField key value' _) = (key, constValue value') + constObjectField Full.ObjectField{value = value', ..} = + (name, constValue $ Full.node value') -- | Rewrites the original syntax tree into an intermediate representation used -- for query execution. @@ -384,7 +385,8 @@ value (Full.List list) = Type.List <$> traverse value list value (Full.Object object) = Type.Object . HashMap.fromList <$> traverse objectField object where - objectField (Full.ObjectField name value' _) = (name,) <$> value value' + objectField Full.ObjectField{value = value', ..} = + (name,) <$> value (Full.node value') input :: forall m. Full.Value -> State (Replacement m) (Maybe Input) input (Full.Variable name) = @@ -400,8 +402,8 @@ input (Full.Object object) = do objectFields <- foldM objectField HashMap.empty object pure $ pure $ Object objectFields where - objectField resultMap (Full.ObjectField name value' _) = - inputField resultMap name value' + objectField resultMap Full.ObjectField{value = value', ..} = + inputField resultMap name $ Full.node value' inputField :: forall m . HashMap Full.Name Input diff --git a/src/Language/GraphQL/Validate.hs b/src/Language/GraphQL/Validate.hs index 5acb26a..d904e8c 100644 --- a/src/Language/GraphQL/Validate.hs +++ b/src/Language/GraphQL/Validate.hs @@ -3,6 +3,7 @@ obtain one at https://mozilla.org/MPL/2.0/. -} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE RecordWildCards #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -349,25 +350,25 @@ variableDefinition :: forall m variableDefinition context rule (Full.VariableDefinition _ typeName value' _) | Just defaultValue' <- value' , variableType <- lookupInputType typeName $ Validation.types context = - constValue rule variableType $ Full.value defaultValue' + constValue rule variableType defaultValue' variableDefinition _ _ _ = mempty constValue :: forall m . Validation.Rule m -> Maybe In.Type - -> Full.ConstValue + -> Full.Node Full.ConstValue -> Seq (Validation.RuleT m) constValue (Validation.ValueRule _ rule) valueType = go valueType where - go inputObjectType value'@(Full.ConstObject fields) + go inputObjectType value'@(Full.Node (Full.ConstObject fields) _) = foldMap (forEach inputObjectType) (Seq.fromList fields) |> rule inputObjectType value' - go listType value'@(Full.ConstList values) - = foldMap (go $ valueTypeFromList listType) (Seq.fromList values) + go listType value'@(Full.Node (Full.ConstList values) location') + = embedListLocation go listType values location' |> rule listType value' go anotherValue value' = pure $ rule anotherValue value' - forEach inputObjectType (Full.ObjectField fieldName fieldValue _) = - go (valueTypeByName fieldName inputObjectType) fieldValue + forEach inputObjectType Full.ObjectField{value = value', ..} = + go (valueTypeByName name inputObjectType) value' constValue _ _ = const mempty inputFieldType :: In.InputField -> In.Type @@ -379,10 +380,6 @@ valueTypeByName fieldName (Just( In.InputObjectBaseType inputObjectType)) = in inputFieldType <$> HashMap.lookup fieldName fieldTypes valueTypeByName _ _ = Nothing -valueTypeFromList :: Maybe In.Type -> Maybe In.Type -valueTypeFromList (Just (In.ListBaseType listType)) = Just listType -valueTypeFromList _ = Nothing - fragmentDefinition :: forall m . Validation.Rule m -> Validation m @@ -468,26 +465,40 @@ argument :: forall m -> Full.Argument -> Seq (Validation.RuleT m) argument rule argumentType (Full.Argument _ value' _) = - value rule (valueType <$> argumentType) $ Full.value value' + value rule (valueType <$> argumentType) value' where valueType (In.Argument _ valueType' _) = valueType' +-- valueTypeFromList :: Maybe In.Type -> Maybe In.Type +embedListLocation :: forall a m + . (Maybe In.Type -> Full.Node a -> Seq m) + -> Maybe In.Type + -> [a] + -> Full.Location + -> Seq m +embedListLocation go listType values location' + = foldMap (go $ valueTypeFromList listType) + $ flip Full.Node location' <$> Seq.fromList values + where + valueTypeFromList (Just (In.ListBaseType baseType)) = Just baseType + valueTypeFromList _ = Nothing + value :: forall m . Validation.Rule m -> Maybe In.Type - -> Full.Value + -> Full.Node Full.Value -> Seq (Validation.RuleT m) value (Validation.ValueRule rule _) valueType = go valueType where - go inputObjectType value'@(Full.Object fields) + go inputObjectType value'@(Full.Node (Full.Object fields) _) = foldMap (forEach inputObjectType) (Seq.fromList fields) |> rule inputObjectType value' - go listType value'@(Full.List values) - = foldMap (go $ valueTypeFromList listType) (Seq.fromList values) + go listType value'@(Full.Node (Full.List values) location') + = embedListLocation go listType values location' |> rule listType value' go anotherValue value' = pure $ rule anotherValue value' - forEach inputObjectType (Full.ObjectField fieldName fieldValue _) = - go (valueTypeByName fieldName inputObjectType) fieldValue + forEach inputObjectType Full.ObjectField{value = value', ..} = + go (valueTypeByName name inputObjectType) value' value _ _ = const mempty inlineFragment :: forall m diff --git a/src/Language/GraphQL/Validate/Rules.hs b/src/Language/GraphQL/Validate/Rules.hs index a5754c6..11f4482 100644 --- a/src/Language/GraphQL/Validate/Rules.hs +++ b/src/Language/GraphQL/Validate/Rules.hs @@ -24,6 +24,7 @@ module Language.GraphQL.Validate.Rules , noUndefinedVariablesRule , noUnusedFragmentsRule , noUnusedVariablesRule + , providedRequiredInputFieldsRule , providedRequiredArgumentsRule , scalarLeafsRule , singleFieldSubscriptionsRule @@ -91,6 +92,7 @@ specifiedRules = -- Values , knownInputFieldNamesRule , uniqueInputFieldNamesRule + , providedRequiredInputFieldsRule -- Directives. , knownDirectiveNamesRule , directivesInValidLocationsRule @@ -646,7 +648,7 @@ variableUsageDifference difference errorMessage = OperationDefinitionRule $ \cas findDirectiveVariables (Directive _ arguments _) = mapArguments arguments mapArguments = Seq.fromList . mapMaybe findArgumentVariables mapDirectives = foldMap findDirectiveVariables - findArgumentVariables (Argument _ Node{ value = Variable value', ..} _) = + findArgumentVariables (Argument _ Node{ node = Variable value', ..} _) = Just (value', [location]) findArgumentVariables _ = Nothing makeError operationName (variableName, locations') = Error @@ -680,12 +682,12 @@ uniqueInputFieldNamesRule :: forall m. Rule m uniqueInputFieldNamesRule = ValueRule (const $ lift . go) (const $ lift . constGo) where - go (Object fields) = filterFieldDuplicates fields + go (Node (Object fields) _) = filterFieldDuplicates fields go _ = mempty filterFieldDuplicates fields = filterDuplicates getFieldName "input field" fields getFieldName (ObjectField fieldName _ location') = (fieldName, location') - constGo (ConstObject fields) = filterFieldDuplicates fields + constGo (Node (ConstObject fields) _) = filterFieldDuplicates fields constGo _ = mempty -- | The target field of a field selection must be defined on the scoped type of @@ -848,11 +850,11 @@ knownDirectiveNamesRule = DirectivesRule $ const $ \directives' -> do knownInputFieldNamesRule :: Rule m knownInputFieldNamesRule = ValueRule go constGo where - go (Just valueType) (Object inputFields) + go (Just valueType) (Node (Object inputFields) _) | In.InputObjectBaseType objectType <- valueType = lift $ Seq.fromList $ mapMaybe (forEach objectType) inputFields go _ _ = lift mempty - constGo (Just valueType) (ConstObject inputFields) + constGo (Just valueType) (Node (ConstObject inputFields) _) | In.InputObjectBaseType objectType <- valueType = lift $ Seq.fromList $ mapMaybe (forEach objectType) inputFields constGo _ _ = lift mempty @@ -915,13 +917,6 @@ providedRequiredArgumentsRule = ArgumentsRule fieldRule directiveRule let forEach = go (directiveMessage directiveName) arguments location' in lift $ HashMap.foldrWithKey forEach Seq.empty definitions _ -> lift mempty - inputTypeName (In.ScalarBaseType (Definition.ScalarType typeName _)) = - typeName - inputTypeName (In.EnumBaseType (Definition.EnumType typeName _ _)) = - typeName - inputTypeName (In.InputObjectBaseType (In.InputObjectType typeName _ _)) = - typeName - inputTypeName (In.ListBaseType listType) = inputTypeName listType go makeMessage arguments location' argumentName argumentType errors | In.Argument _ type' optionalValue <- argumentType , In.isNonNullType type' @@ -956,3 +951,49 @@ providedRequiredArgumentsRule = ArgumentsRule fieldRule directiveRule , Text.unpack typeName , "\" is required, but it was not provided." ] + +inputTypeName :: In.Type -> Text +inputTypeName (In.ScalarBaseType (Definition.ScalarType typeName _)) = typeName +inputTypeName (In.EnumBaseType (Definition.EnumType typeName _ _)) = typeName +inputTypeName (In.InputObjectBaseType (In.InputObjectType typeName _ _)) = + typeName +inputTypeName (In.ListBaseType listType) = inputTypeName listType + +-- | Input object fields may be required. Much like a field may have required +-- arguments, an input object may have required fields. An input field is +-- required if it has a non‐null type and does not have a default value. +-- Otherwise, the input object field is optional. +providedRequiredInputFieldsRule :: Rule m +providedRequiredInputFieldsRule = ValueRule go constGo + where + go (Just valueType) (Node (Object inputFields) location') + | In.InputObjectBaseType objectType <- valueType + , In.InputObjectType objectTypeName _ fieldDefinitions <- objectType + = lift + $ Seq.fromList + $ HashMap.elems + $ flip HashMap.mapMaybeWithKey fieldDefinitions + $ forEach inputFields objectTypeName location' + go _ _ = lift mempty + constGo _ _ = lift mempty + forEach inputFields typeName location' definitionName fieldDefinition + | In.InputField _ inputType optionalValue <- fieldDefinition + , In.isNonNullType inputType + , isNothing optionalValue + , isNothingOrNull $ find (lookupField definitionName) inputFields = + Just $ makeError definitionName typeName location' + | otherwise = Nothing + isNothingOrNull (Just (ObjectField _ (Node Null _) _)) = True + isNothingOrNull x = isNothing x + lookupField needle (ObjectField fieldName _ _) = needle == fieldName + makeError fieldName typeName location' = Error + { message = errorMessage fieldName typeName + , locations = [location'] + } + errorMessage fieldName typeName = concat + [ "Input field \"" + , Text.unpack fieldName + , "\" of type \"" + , Text.unpack typeName + , "\" is required, but it was not provided." + ] diff --git a/src/Language/GraphQL/Validate/Validation.hs b/src/Language/GraphQL/Validate/Validation.hs index 32a454e..7ffab10 100644 --- a/src/Language/GraphQL/Validate/Validation.hs +++ b/src/Language/GraphQL/Validate/Validation.hs @@ -48,7 +48,7 @@ data Rule m | ArgumentsRule (Maybe (Out.Type m) -> Field -> RuleT m) (Directive -> RuleT m) | DirectivesRule (DirectiveLocation -> [Directive] -> RuleT m) | VariablesRule ([VariableDefinition] -> RuleT m) - | ValueRule (Maybe In.Type -> Value -> RuleT m) (Maybe In.Type -> ConstValue -> RuleT m) + | ValueRule (Maybe In.Type -> Node Value -> RuleT m) (Maybe In.Type -> Node ConstValue -> RuleT m) -- | Monad transformer used by the rules. type RuleT m = ReaderT (Validation m) Seq Error diff --git a/tests/Language/GraphQL/ValidateSpec.hs b/tests/Language/GraphQL/ValidateSpec.hs index 60e717a..92b3001 100644 --- a/tests/Language/GraphQL/ValidateSpec.hs +++ b/tests/Language/GraphQL/ValidateSpec.hs @@ -594,7 +594,7 @@ spec = it "rejects undefined input object fields" $ let queryString = [r| { - findDog(complex: { favoriteCookieFlavor: "Bacon" }) { + findDog(complex: { favoriteCookieFlavor: "Bacon", name: "Jack" }) { name } } @@ -620,3 +620,19 @@ spec = , locations = [AST.Location 2 21] } in validate queryString `shouldBe` [expected] + + it "rejects missing required input fields" $ + let queryString = [r| + { + findDog(complex: { name: null }) { + name + } + } + |] + expected = Error + { message = + "Input field \"name\" of type \"DogData\" is required, \ + \but it was not provided." + , locations = [AST.Location 3 34] + } + in validate queryString `shouldBe` [expected]