Validate non-nullable values inside lists

This commit is contained in:
Eugen Wissner 2021-05-06 22:23:16 +02:00
parent 2220f0ca56
commit 5a5f265fe4
5 changed files with 49 additions and 32 deletions

View File

@ -18,6 +18,7 @@ and this project adheres to
validation traverser calls it on all list items. validation traverser calls it on all list items.
- `valuesOfCorrectTypeRule` doesn't check objects recursively since the - `valuesOfCorrectTypeRule` doesn't check objects recursively since the
validation traverser calls it on all object properties. validation traverser calls it on all object properties.
- Validation of non-nullable values inside lists.
### Changed ### Changed
- `AST.Document.Value.List` and `AST.Document.ConstValue.ConstList` contain - `AST.Document.Value.List` and `AST.Document.ConstValue.ConstList` contain

View File

@ -314,9 +314,6 @@ constValue (Validation.ValueRule _ rule) valueType = go valueType
go inputObjectType value'@(Full.Node (Full.ConstObject fields) _) go inputObjectType value'@(Full.Node (Full.ConstObject fields) _)
= foldMap (forEach inputObjectType) (Seq.fromList fields) = foldMap (forEach inputObjectType) (Seq.fromList fields)
|> rule inputObjectType value' |> rule inputObjectType value'
go listType value'@(Full.Node (Full.ConstList values) _location)
= embedListLocation go listType values
|> rule listType value'
go anotherValue value' = pure $ rule anotherValue value' go anotherValue value' = pure $ rule anotherValue value'
forEach inputObjectType Full.ObjectField{value = value', ..} = forEach inputObjectType Full.ObjectField{value = value', ..} =
go (valueTypeByName name inputObjectType) value' go (valueTypeByName name inputObjectType) value'
@ -420,19 +417,6 @@ argument rule argumentType (Full.Argument _ value' _) =
where where
valueType (In.Argument _ valueType' _) = valueType' valueType (In.Argument _ valueType' _) = valueType'
-- Applies a validation rule to each list value and merges returned errors.
embedListLocation :: forall a m
. (Maybe In.Type -> Full.Node a -> Seq m)
-> Maybe In.Type
-> [Full.Node a]
-> Seq m
embedListLocation go listType
= foldMap (go $ valueTypeFromList listType)
. Seq.fromList
where
valueTypeFromList (Just (In.ListBaseType baseType)) = Just baseType
valueTypeFromList _ = Nothing
value :: forall m value :: forall m
. Validation.Rule m . Validation.Rule m
-> Maybe In.Type -> Maybe In.Type
@ -443,9 +427,6 @@ value (Validation.ValueRule rule _) valueType = go valueType
go inputObjectType value'@(Full.Node (Full.Object fields) _) go inputObjectType value'@(Full.Node (Full.Object fields) _)
= foldMap (forEach inputObjectType) (Seq.fromList fields) = foldMap (forEach inputObjectType) (Seq.fromList fields)
|> rule inputObjectType value' |> rule inputObjectType value'
go listType value'@(Full.Node (Full.List values) _location)
= embedListLocation go listType values
|> rule listType value'
go anotherValue value' = pure $ rule anotherValue value' go anotherValue value' = pure $ rule anotherValue value'
forEach inputObjectType Full.ObjectField{value = value', ..} = forEach inputObjectType Full.ObjectField{value = value', ..} =
go (valueTypeByName name inputObjectType) value' go (valueTypeByName name inputObjectType) value'

View File

@ -1502,15 +1502,6 @@ variablesInAllowedPositionRule = OperationDefinitionRule $ \case
hasNonNullVariableDefaultValue (Just (Full.Node Full.ConstNull _)) = False hasNonNullVariableDefaultValue (Just (Full.Node Full.ConstNull _)) = False
hasNonNullVariableDefaultValue Nothing = False hasNonNullVariableDefaultValue Nothing = False
hasNonNullVariableDefaultValue _ = True hasNonNullVariableDefaultValue _ = True
unwrapInType (In.NonNullScalarType nonNullType) =
Just $ In.NamedScalarType nonNullType
unwrapInType (In.NonNullEnumType nonNullType) =
Just $ In.NamedEnumType nonNullType
unwrapInType (In.NonNullInputObjectType nonNullType) =
Just $ In.NamedInputObjectType nonNullType
unwrapInType (In.NonNullListType nonNullType) =
Just $ In.ListType nonNullType
unwrapInType _ = Nothing
makeError variableDefinition expectedType = makeError variableDefinition expectedType =
let Full.VariableDefinition variableName variableType _ location' = let Full.VariableDefinition variableName variableType _ location' =
variableDefinition variableDefinition
@ -1527,6 +1518,17 @@ variablesInAllowedPositionRule = OperationDefinitionRule $ \case
, locations = [location'] , locations = [location']
} }
unwrapInType :: In.Type -> Maybe In.Type
unwrapInType (In.NonNullScalarType nonNullType) =
Just $ In.NamedScalarType nonNullType
unwrapInType (In.NonNullEnumType nonNullType) =
Just $ In.NamedEnumType nonNullType
unwrapInType (In.NonNullInputObjectType nonNullType) =
Just $ In.NamedInputObjectType nonNullType
unwrapInType (In.NonNullListType nonNullType) =
Just $ In.ListType nonNullType
unwrapInType _ = Nothing
-- | Literal values must be compatible with the type expected in the position -- | Literal values must be compatible with the type expected in the position
-- they are found as per the coercion rules. -- they are found as per the coercion rules.
-- --
@ -1586,8 +1588,8 @@ valuesOfCorrectTypeRule = ValueRule go constGo
| In.InputObjectType{} <- objectType | In.InputObjectType{} <- objectType
, Full.ConstObject{} <- node = mempty , Full.ConstObject{} <- node = mempty
check (In.ListBaseType listType) constValue@Full.Node{ .. } check (In.ListBaseType listType) constValue@Full.Node{ .. }
-- Skip, lists are checked recursively by the validation traverser. | Full.ConstList values <- node =
| Full.ConstList{} <- node = mempty foldMap (checkNull listType) values
| otherwise = check listType constValue | otherwise = check listType constValue
check inputType Full.Node{ .. } = pure $ Error check inputType Full.Node{ .. } = pure $ Error
{ message = concat { message = concat
@ -1599,3 +1601,20 @@ valuesOfCorrectTypeRule = ValueRule go constGo
] ]
, locations = [location] , locations = [location]
} }
checkNull inputType constValue =
let checkResult = check inputType constValue
in case null checkResult of
True
| Just unwrappedType <- unwrapInType inputType
, Full.Node{ node = Full.ConstNull, .. } <- constValue ->
pure $ Error
{ message = concat
[ "List of non-null values of type \""
, show unwrappedType
, "\" cannot contain null values."
]
, locations = [location]
}
| otherwise -> mempty
_ -> checkResult

View File

@ -1,4 +1,4 @@
resolver: lts-17.9 resolver: lts-17.10
packages: packages:
- . - .

View File

@ -847,7 +847,7 @@ spec =
} }
in validate queryString `shouldBe` [expected] in validate queryString `shouldBe` [expected]
context "providedRequiredArgumentsRule" $ context "providedRequiredArgumentsRule" $ do
it "checks for (non-)nullable arguments" $ it "checks for (non-)nullable arguments" $
let queryString = [r| let queryString = [r|
{ {
@ -944,3 +944,19 @@ spec =
, locations = [AST.Location 3 46] , locations = [AST.Location 3 46]
} }
in validate queryString `shouldBe` [expected] in validate queryString `shouldBe` [expected]
it "checks for required list members" $
let queryString = [r|
{
cat {
doesKnowCommands(catCommands: [null])
}
}
|]
expected = Error
{ message =
"List of non-null values of type \"CatCommand\" \
\cannot contain null values."
, locations = [AST.Location 4 54]
}
in validate queryString `shouldBe` [expected]