Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Forbid @skip and @include directives in subscription root selection #3974

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
60 changes: 51 additions & 9 deletions src/execution/collectFields.ts
Expand Up @@ -5,6 +5,7 @@ import { isSameSet } from '../jsutils/isSameSet.js';
import type { ObjMap } from '../jsutils/ObjMap.js';

import type {
DirectiveNode,
FieldNode,
FragmentDefinitionNode,
FragmentSpreadNode,
Expand All @@ -26,7 +27,7 @@ import type { GraphQLSchema } from '../type/schema.js';

import { typeFromAST } from '../utilities/typeFromAST.js';

import { getDirectiveValues } from './values.js';
import { getArgumentValues, getDirectiveValues } from './values.js';

export interface DeferUsage {
label: string | undefined;
Expand Down Expand Up @@ -60,6 +61,7 @@ export interface CollectFieldsResult {
groupedFieldSet: GroupedFieldSet;
newGroupedFieldSetDetails: Map<DeferUsageSet, GroupedFieldSetDetails>;
newDeferUsages: ReadonlyArray<DeferUsage>;
forbiddenDirectiveInstances: ReadonlyArray<DirectiveNode>;
}

interface CollectFieldsContext {
Expand All @@ -72,6 +74,7 @@ interface CollectFieldsContext {
fieldsByTarget: Map<Target, AccumulatorMap<string, FieldNode>>;
newDeferUsages: Array<DeferUsage>;
visitedFragmentNames: Set<string>;
forbiddenDirectiveInstances: Array<DirectiveNode>;
}

/**
Expand Down Expand Up @@ -100,16 +103,28 @@ export function collectFields(
targetsByKey: new Map(),
newDeferUsages: [],
visitedFragmentNames: new Set(),
forbiddenDirectiveInstances: [],
};

collectFieldsImpl(context, operation.selectionSet);

return {
...buildGroupedFieldSets(context.targetsByKey, context.fieldsByTarget),
newDeferUsages: context.newDeferUsages,
forbiddenDirectiveInstances: context.forbiddenDirectiveInstances,
};
}

/**
* This variable is the empty variables used during the validation phase (where
* no variables exist) for field collection; if a `@skip` or `@include`
* directive is ever seen when `variableValues` is set to this, it should
* throw.
*/
export const VALIDATION_PHASE_EMPTY_VARIABLES: {
[variable: string]: any;
} = Object.freeze(Object.create(null));

/**
* Given an array of field nodes, collects all of the subfields of the passed
* in fields, and returns them at the end.
Expand Down Expand Up @@ -139,6 +154,7 @@ export function collectSubfields(
targetsByKey: new Map(),
newDeferUsages: [],
visitedFragmentNames: new Set(),
forbiddenDirectiveInstances: [],
};

for (const fieldDetails of fieldGroup.fields) {
Expand All @@ -155,6 +171,7 @@ export function collectSubfields(
fieldGroup.targets,
),
newDeferUsages: context.newDeferUsages,
forbiddenDirectiveInstances: context.forbiddenDirectiveInstances,
};
}

Expand All @@ -179,7 +196,7 @@ function collectFieldsImpl(
for (const selection of selectionSet.selections) {
switch (selection.kind) {
case Kind.FIELD: {
if (!shouldIncludeNode(variableValues, selection)) {
if (!shouldIncludeNode(context, variableValues, selection)) {
continue;
}
const key = getFieldEntryKey(selection);
Expand All @@ -200,7 +217,7 @@ function collectFieldsImpl(
}
case Kind.INLINE_FRAGMENT: {
if (
!shouldIncludeNode(variableValues, selection) ||
!shouldIncludeNode(context, variableValues, selection) ||
!doesFragmentConditionMatch(schema, selection, runtimeType)
) {
continue;
Expand Down Expand Up @@ -232,7 +249,7 @@ function collectFieldsImpl(
case Kind.FRAGMENT_SPREAD: {
const fragName = selection.name.value;

if (!shouldIncludeNode(variableValues, selection)) {
if (!shouldIncludeNode(context, variableValues, selection)) {
continue;
}

Expand Down Expand Up @@ -304,19 +321,44 @@ function getDeferValues(
* directives, where `@skip` has higher precedence than `@include`.
*/
function shouldIncludeNode(
context: CollectFieldsContext,
variableValues: { [variable: string]: unknown },
node: FragmentSpreadNode | FieldNode | InlineFragmentNode,
): boolean {
const skip = getDirectiveValues(GraphQLSkipDirective, node, variableValues);
const skipDirectiveNode = node.directives?.find(
(directive) => directive.name.value === GraphQLSkipDirective.name,
);
if (
skipDirectiveNode &&
variableValues === VALIDATION_PHASE_EMPTY_VARIABLES
) {
context.forbiddenDirectiveInstances.push(skipDirectiveNode);
return false;
}
const skip = skipDirectiveNode
? getArgumentValues(GraphQLSkipDirective, skipDirectiveNode, variableValues)
: undefined;
if (skip?.if === true) {
return false;
}

const include = getDirectiveValues(
GraphQLIncludeDirective,
node,
variableValues,
const includeDirectiveNode = node.directives?.find(
(directive) => directive.name.value === GraphQLIncludeDirective.name,
);
if (
includeDirectiveNode &&
variableValues === VALIDATION_PHASE_EMPTY_VARIABLES
) {
context.forbiddenDirectiveInstances.push(includeDirectiveNode);
return false;
}
const include = includeDirectiveNode
? getArgumentValues(
GraphQLIncludeDirective,
includeDirectiveNode,
variableValues,
)
: undefined;
if (include?.if === false) {
return false;
}
Expand Down
42 changes: 42 additions & 0 deletions src/validation/__tests__/SingleFieldSubscriptionsRule-test.ts
Expand Up @@ -286,6 +286,48 @@ describe('Validate: Subscriptions with single field', () => {
]);
});

it('fails with @skip or @include directive', () => {
expectErrors(`
subscription RequiredRuntimeValidation($bool: Boolean!) {
newMessage @include(if: $bool) {
body
sender
}
disallowedSecondRootField @skip(if: $bool)
}
`).toDeepEqual([
{
message:
'Subscription "RequiredRuntimeValidation" must not use `@skip` or `@include` directives in the top level selection.',
locations: [
{ line: 3, column: 20 },
{ line: 7, column: 35 },
],
},
]);
});

it('fails with @skip or @include directive in anonymous subscription', () => {
expectErrors(`
subscription ($bool: Boolean!) {
newMessage @include(if: $bool) {
body
sender
}
disallowedSecondRootField @skip(if: $bool)
}
`).toDeepEqual([
{
message:
'Anonymous Subscription must not use `@skip` or `@include` directives in the top level selection.',
locations: [
{ line: 3, column: 20 },
{ line: 7, column: 35 },
],
},
]);
});

it('skips if not subscription type', () => {
const emptySchema = buildSchema(`
type Query {
Expand Down
38 changes: 26 additions & 12 deletions src/validation/rules/SingleFieldSubscriptionsRule.ts
Expand Up @@ -11,7 +11,10 @@ import { Kind } from '../../language/kinds.js';
import type { ASTVisitor } from '../../language/visitor.js';

import type { FieldGroup } from '../../execution/collectFields.js';
import { collectFields } from '../../execution/collectFields.js';
import {
collectFields,
VALIDATION_PHASE_EMPTY_VARIABLES,
} from '../../execution/collectFields.js';

import type { ValidationContext } from '../ValidationContext.js';

Expand All @@ -23,7 +26,8 @@ function toNodes(fieldGroup: FieldGroup): ReadonlyArray<FieldNode> {
* Subscriptions must only include a non-introspection field.
*
* A GraphQL subscription is valid only if it contains a single root field and
* that root field is not an introspection field.
* that root field is not an introspection field. `@skip` and `@include`
* directives are forbidden.
*
* See https://spec.graphql.org/draft/#sec-Single-root-field
*/
Expand All @@ -37,23 +41,33 @@ export function SingleFieldSubscriptionsRule(
const subscriptionType = schema.getSubscriptionType();
if (subscriptionType) {
const operationName = node.name ? node.name.value : null;
const variableValues: {
[variable: string]: any;
} = Object.create(null);
const variableValues = VALIDATION_PHASE_EMPTY_VARIABLES;
const document = context.getDocument();
const fragments: ObjMap<FragmentDefinitionNode> = Object.create(null);
for (const definition of document.definitions) {
if (definition.kind === Kind.FRAGMENT_DEFINITION) {
fragments[definition.name.value] = definition;
}
}
const { groupedFieldSet } = collectFields(
schema,
fragments,
variableValues,
subscriptionType,
node,
);
const { groupedFieldSet, forbiddenDirectiveInstances } =
collectFields(
schema,
fragments,
variableValues,
subscriptionType,
node,
);
if (forbiddenDirectiveInstances.length > 0) {
context.reportError(
new GraphQLError(
operationName != null
? `Subscription "${operationName}" must not use \`@skip\` or \`@include\` directives in the top level selection.`
: 'Anonymous Subscription must not use `@skip` or `@include` directives in the top level selection.',
{ nodes: forbiddenDirectiveInstances },
),
);
return;
}
if (groupedFieldSet.size > 1) {
const fieldGroups = [...groupedFieldSet.values()];
const extraFieldGroups = fieldGroups.slice(1);
Expand Down