From 5a5f265fe4bf291c1ef58f5fe452f1e8c69c9ed6 Mon Sep 17 00:00:00 2001 From: Eugen Wissner Date: Thu, 6 May 2021 22:23:16 +0200 Subject: [PATCH] Validate non-nullable values inside lists --- CHANGELOG.md | 1 + src/Language/GraphQL/Validate.hs | 19 --------- src/Language/GraphQL/Validate/Rules.hs | 41 ++++++++++++++------ stack.yaml | 2 +- tests/Language/GraphQL/Validate/RulesSpec.hs | 18 ++++++++- 5 files changed, 49 insertions(+), 32 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2854191..ad90db9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ and this project adheres to validation traverser calls it on all list items. - `valuesOfCorrectTypeRule` doesn't check objects recursively since the validation traverser calls it on all object properties. +- Validation of non-nullable values inside lists. ### Changed - `AST.Document.Value.List` and `AST.Document.ConstValue.ConstList` contain diff --git a/src/Language/GraphQL/Validate.hs b/src/Language/GraphQL/Validate.hs index 2e86e53..f929b98 100644 --- a/src/Language/GraphQL/Validate.hs +++ b/src/Language/GraphQL/Validate.hs @@ -314,9 +314,6 @@ constValue (Validation.ValueRule _ rule) valueType = go valueType go inputObjectType value'@(Full.Node (Full.ConstObject fields) _) = foldMap (forEach inputObjectType) (Seq.fromList fields) |> rule inputObjectType value' - go listType value'@(Full.Node (Full.ConstList values) _location) - = embedListLocation go listType values - |> rule listType value' go anotherValue value' = pure $ rule anotherValue value' forEach inputObjectType Full.ObjectField{value = value', ..} = go (valueTypeByName name inputObjectType) value' @@ -420,19 +417,6 @@ argument rule argumentType (Full.Argument _ value' _) = where valueType (In.Argument _ valueType' _) = valueType' --- Applies a validation rule to each list value and merges returned errors. -embedListLocation :: forall a m - . (Maybe In.Type -> Full.Node a -> Seq m) - -> Maybe In.Type - -> [Full.Node a] - -> Seq m -embedListLocation go listType - = foldMap (go $ valueTypeFromList listType) - . Seq.fromList - where - valueTypeFromList (Just (In.ListBaseType baseType)) = Just baseType - valueTypeFromList _ = Nothing - value :: forall m . Validation.Rule m -> Maybe In.Type @@ -443,9 +427,6 @@ value (Validation.ValueRule rule _) valueType = go valueType go inputObjectType value'@(Full.Node (Full.Object fields) _) = foldMap (forEach inputObjectType) (Seq.fromList fields) |> rule inputObjectType value' - go listType value'@(Full.Node (Full.List values) _location) - = embedListLocation go listType values - |> rule listType value' go anotherValue value' = pure $ rule anotherValue value' forEach inputObjectType Full.ObjectField{value = value', ..} = go (valueTypeByName name inputObjectType) value' diff --git a/src/Language/GraphQL/Validate/Rules.hs b/src/Language/GraphQL/Validate/Rules.hs index a6bc43b..46a14b7 100644 --- a/src/Language/GraphQL/Validate/Rules.hs +++ b/src/Language/GraphQL/Validate/Rules.hs @@ -1502,15 +1502,6 @@ variablesInAllowedPositionRule = OperationDefinitionRule $ \case hasNonNullVariableDefaultValue (Just (Full.Node Full.ConstNull _)) = False hasNonNullVariableDefaultValue Nothing = False hasNonNullVariableDefaultValue _ = True - unwrapInType (In.NonNullScalarType nonNullType) = - Just $ In.NamedScalarType nonNullType - unwrapInType (In.NonNullEnumType nonNullType) = - Just $ In.NamedEnumType nonNullType - unwrapInType (In.NonNullInputObjectType nonNullType) = - Just $ In.NamedInputObjectType nonNullType - unwrapInType (In.NonNullListType nonNullType) = - Just $ In.ListType nonNullType - unwrapInType _ = Nothing makeError variableDefinition expectedType = let Full.VariableDefinition variableName variableType _ location' = variableDefinition @@ -1527,6 +1518,17 @@ variablesInAllowedPositionRule = OperationDefinitionRule $ \case , locations = [location'] } +unwrapInType :: In.Type -> Maybe In.Type +unwrapInType (In.NonNullScalarType nonNullType) = + Just $ In.NamedScalarType nonNullType +unwrapInType (In.NonNullEnumType nonNullType) = + Just $ In.NamedEnumType nonNullType +unwrapInType (In.NonNullInputObjectType nonNullType) = + Just $ In.NamedInputObjectType nonNullType +unwrapInType (In.NonNullListType nonNullType) = + Just $ In.ListType nonNullType +unwrapInType _ = Nothing + -- | Literal values must be compatible with the type expected in the position -- they are found as per the coercion rules. -- @@ -1586,8 +1588,8 @@ valuesOfCorrectTypeRule = ValueRule go constGo | In.InputObjectType{} <- objectType , Full.ConstObject{} <- node = mempty check (In.ListBaseType listType) constValue@Full.Node{ .. } - -- Skip, lists are checked recursively by the validation traverser. - | Full.ConstList{} <- node = mempty + | Full.ConstList values <- node = + foldMap (checkNull listType) values | otherwise = check listType constValue check inputType Full.Node{ .. } = pure $ Error { message = concat @@ -1599,3 +1601,20 @@ valuesOfCorrectTypeRule = ValueRule go constGo ] , locations = [location] } + checkNull inputType constValue = + let checkResult = check inputType constValue + in case null checkResult of + True + | Just unwrappedType <- unwrapInType inputType + , Full.Node{ node = Full.ConstNull, .. } <- constValue -> + pure $ Error + { message = concat + [ "List of non-null values of type \"" + , show unwrappedType + , "\" cannot contain null values." + ] + , locations = [location] + } + | otherwise -> mempty + _ -> checkResult + diff --git a/stack.yaml b/stack.yaml index 36e6de5..97b3334 100644 --- a/stack.yaml +++ b/stack.yaml @@ -1,4 +1,4 @@ -resolver: lts-17.9 +resolver: lts-17.10 packages: - . diff --git a/tests/Language/GraphQL/Validate/RulesSpec.hs b/tests/Language/GraphQL/Validate/RulesSpec.hs index 02b14ae..f75aef6 100644 --- a/tests/Language/GraphQL/Validate/RulesSpec.hs +++ b/tests/Language/GraphQL/Validate/RulesSpec.hs @@ -847,7 +847,7 @@ spec = } in validate queryString `shouldBe` [expected] - context "providedRequiredArgumentsRule" $ + context "providedRequiredArgumentsRule" $ do it "checks for (non-)nullable arguments" $ let queryString = [r| { @@ -944,3 +944,19 @@ spec = , locations = [AST.Location 3 46] } in validate queryString `shouldBe` [expected] + + it "checks for required list members" $ + let queryString = [r| + { + cat { + doesKnowCommands(catCommands: [null]) + } + } + |] + expected = Error + { message = + "List of non-null values of type \"CatCommand\" \ + \cannot contain null values." + , locations = [AST.Location 4 54] + } + in validate queryString `shouldBe` [expected]