Просмотр исходного кода

fix(core): Fix issues caused by f235249f

The fix f235249f inadvertently broke transactions across field
resolvers in all databases apart from SQLite. This commit solves that.
Michael Bromley 1 год назад
Родитель
Сommit
5a4299a114

+ 2 - 2
packages/asset-server-plugin/src/common.ts

@@ -1,4 +1,4 @@
-import { REQUEST_CONTEXT_KEY } from '@vendure/core/dist/common/constants';
+import { internal_getRequestContext } from '@vendure/core';
 import { Request } from 'express';
 
 import { AssetServerOptions, ImageTransformFormat } from './types';
@@ -18,7 +18,7 @@ export function getAssetUrlPrefixFn(options: AssetServerOptions) {
     }
     if (typeof assetUrlPrefix === 'function') {
         return (request: Request, identifier: string) => {
-            const ctx = (request as any)[REQUEST_CONTEXT_KEY];
+            const ctx = internal_getRequestContext(request);
             return assetUrlPrefix(ctx, identifier);
         };
     }

+ 13 - 0
packages/core/src/api/common/is-field-resolver.ts

@@ -0,0 +1,13 @@
+import { GraphQLResolveInfo } from 'graphql';
+
+/**
+ * Returns true is this guard is being called on a FieldResolver, i.e. not a top-level
+ * Query or Mutation resolver.
+ */
+export function isFieldResolver(info?: GraphQLResolveInfo): boolean {
+    if (!info) {
+        return false;
+    }
+    const parentType = info?.parentType?.name;
+    return parentType !== 'Query' && parentType !== 'Mutation' && parentType !== 'Subscription';
+}

+ 70 - 12
packages/core/src/api/common/request-context.ts

@@ -4,8 +4,13 @@ import { ID, JsonCompatible } from '@vendure/common/lib/shared-types';
 import { isObject } from '@vendure/common/lib/shared-utils';
 import { Request } from 'express';
 import { TFunction } from 'i18next';
+import { EntityManager } from 'typeorm';
 
-import { REQUEST_CONTEXT_KEY, REQUEST_CONTEXT_MAP_KEY } from '../../common/constants';
+import {
+    REQUEST_CONTEXT_KEY,
+    REQUEST_CONTEXT_MAP_KEY,
+    TRANSACTION_MANAGER_KEY,
+} from '../../common/constants';
 import { idsAreEqual } from '../../common/utils';
 import { CachedSession } from '../../config/session-cache/session-cache-strategy';
 import { Channel } from '../../entity/channel/channel.entity';
@@ -22,6 +27,32 @@ export type SerializedRequestContext = {
     _authorizedAsOwnerOnly: boolean;
 };
 
+/**
+ * This object is used to store the RequestContext on the Express Request object.
+ */
+interface RequestContextStore {
+    /**
+     * This is the default RequestContext for the handler.
+     */
+    default: RequestContext;
+    /**
+     * If a transaction is started, the resulting RequestContext is stored here.
+     * This RequestContext will have a transaction manager attached via the
+     * TRANSACTION_MANAGER_KEY symbol.
+     *
+     * When a transaction is started, the TRANSACTION_MANAGER_KEY symbol is added to the RequestContext
+     * object. This is then detected inside the {@link internal_setRequestContext} function and the
+     * RequestContext object is stored in the RequestContextStore under the withTransactionManager key.
+     */
+    withTransactionManager?: RequestContext;
+}
+
+interface RequestWithStores extends Request {
+    // eslint-disable-next-line @typescript-eslint/ban-types
+    [REQUEST_CONTEXT_MAP_KEY]?: Map<Function, RequestContextStore>;
+    [REQUEST_CONTEXT_KEY]?: RequestContextStore;
+}
+
 /**
  * @description
  * This function is used to set the {@link RequestContext} on the `req` object. This is the underlying
@@ -42,23 +73,39 @@ export type SerializedRequestContext = {
  * We named it this way to discourage usage outside the framework internals.
  */
 export function internal_setRequestContext(
-    req: Request,
+    req: RequestWithStores,
     ctx: RequestContext,
     executionContext?: ExecutionContext,
 ) {
     // If we have access to the `ExecutionContext`, it means we are able to bind
     // the `ctx` object to the specific "handler", i.e. the resolver function (for GraphQL)
     // or controller (for REST).
+    let item: RequestContextStore | undefined;
     if (executionContext && typeof executionContext.getHandler === 'function') {
         // eslint-disable-next-line @typescript-eslint/ban-types
-        const map: Map<Function, RequestContext> = (req as any)[REQUEST_CONTEXT_MAP_KEY] || new Map();
-        map.set(executionContext.getHandler(), ctx);
+        const map = req[REQUEST_CONTEXT_MAP_KEY] || new Map();
+        item = map.get(executionContext.getHandler());
+        const ctxHasTransaction = Object.getOwnPropertySymbols(ctx).includes(TRANSACTION_MANAGER_KEY);
+        if (item) {
+            item.default = item.default ?? ctx;
+            if (ctxHasTransaction) {
+                item.withTransactionManager = ctx;
+            }
+        } else {
+            item = {
+                default: ctx,
+                withTransactionManager: ctxHasTransaction ? ctx : undefined,
+            };
+        }
+        map.set(executionContext.getHandler(), item);
 
-        (req as any)[REQUEST_CONTEXT_MAP_KEY] = map;
+        req[REQUEST_CONTEXT_MAP_KEY] = map;
     }
     // We also bind to a shared key so that we can access the `ctx` object
     // later even if we don't have a reference to the `ExecutionContext`
-    (req as any)[REQUEST_CONTEXT_KEY] = ctx;
+    req[REQUEST_CONTEXT_KEY] = item ?? {
+        default: ctx,
+    };
 }
 
 /**
@@ -67,20 +114,31 @@ export function internal_setRequestContext(
  * for more details on this mechanism.
  */
 export function internal_getRequestContext(
-    req: Request,
+    req: RequestWithStores,
     executionContext?: ExecutionContext,
 ): RequestContext {
+    let item: RequestContextStore | undefined;
     if (executionContext && typeof executionContext.getHandler === 'function') {
         // eslint-disable-next-line @typescript-eslint/ban-types
-        const map: Map<Function, RequestContext> | undefined = (req as any)[REQUEST_CONTEXT_MAP_KEY];
-        const ctx = map?.get(executionContext.getHandler());
+        const map = req[REQUEST_CONTEXT_MAP_KEY];
+        item = map?.get(executionContext.getHandler());
         // If we have a ctx associated with the current handler (resolver function), we
         // return it. Otherwise, we fall back to the shared key which will be there.
-        if (ctx) {
-            return ctx;
+        if (item) {
+            return item.withTransactionManager || item.default;
         }
     }
-    return (req as any)[REQUEST_CONTEXT_KEY];
+    if (!item) {
+        item = req[REQUEST_CONTEXT_KEY] as RequestContextStore;
+    }
+    const transactionalCtx =
+        item?.withTransactionManager &&
+        ((item.withTransactionManager as any)[TRANSACTION_MANAGER_KEY] as EntityManager | undefined)
+            ?.queryRunner?.isReleased === false
+            ? item.withTransactionManager
+            : undefined;
+
+    return transactionalCtx || item.default;
 }
 
 /**

+ 4 - 5
packages/core/src/api/config/generate-resolvers.ts

@@ -3,7 +3,6 @@ import { StockMovementType } from '@vendure/common/lib/generated-types';
 import { GraphQLSchema } from 'graphql';
 import { GraphQLDateTime, GraphQLJSON } from 'graphql-scalars';
 
-import { REQUEST_CONTEXT_KEY } from '../../common/constants';
 import { InternalServerError } from '../../common/error/errors';
 import {
     adminErrorOperationTypeResolvers,
@@ -18,7 +17,7 @@ import { Region } from '../../entity/region/region.entity';
 import { getPluginAPIExtensions } from '../../plugin/plugin-metadata';
 import { CustomFieldRelationResolverService } from '../common/custom-field-relation-resolver.service';
 import { ApiType } from '../common/get-api-type';
-import { RequestContext } from '../common/request-context';
+import { internal_getRequestContext } from '../common/request-context';
 import { userHasPermissionsOnCustomField } from '../common/user-has-permissions-on-custom-field';
 
 import { getCustomFieldsConfigWithoutInterfaces } from './get-custom-fields-config-without-interfaces';
@@ -206,7 +205,7 @@ function generateCustomFieldRelationResolvers(
             let resolver: IFieldResolver<any, any>;
             if (isRelationalType(fieldDef)) {
                 resolver = async (source: any, args: any, context: any) => {
-                    const ctx: RequestContext = context.req[REQUEST_CONTEXT_KEY];
+                    const ctx = internal_getRequestContext(context.req);
                     if (!userHasPermissionsOnCustomField(ctx, fieldDef)) {
                         return null;
                     }
@@ -235,7 +234,7 @@ function generateCustomFieldRelationResolvers(
                 };
             } else {
                 resolver = async (source: any, args: any, context: any) => {
-                    const ctx: RequestContext = context.req[REQUEST_CONTEXT_KEY];
+                    const ctx = internal_getRequestContext(context.req);
                     if (!userHasPermissionsOnCustomField(ctx, fieldDef)) {
                         return null;
                     }
@@ -271,7 +270,7 @@ function generateCustomFieldRelationResolvers(
 
 function getCustomScalars(configService: ConfigService, apiType: 'admin' | 'shop') {
     return getPluginAPIExtensions(configService.plugins, apiType)
-        .map(e => (typeof e.scalars === 'function' ? e.scalars() : e.scalars ?? {}))
+        .map(e => (typeof e.scalars === 'function' ? e.scalars() : (e.scalars ?? {})))
         .reduce(
             (all, scalarMap) => ({
                 ...all,

+ 7 - 9
packages/core/src/api/decorators/request-context.decorator.ts

@@ -1,5 +1,7 @@
-import { ContextType, createParamDecorator, ExecutionContext } from '@nestjs/common';
+import { createParamDecorator, ExecutionContext } from '@nestjs/common';
 
+import { isFieldResolver } from '../common/is-field-resolver';
+import { parseContext } from '../common/parse-context';
 import { internal_getRequestContext } from '../common/request-context';
 
 /**
@@ -18,12 +20,8 @@ import { internal_getRequestContext } from '../common/request-context';
  * @docsCategory request
  * @docsPage Ctx Decorator
  */
-export const Ctx = createParamDecorator((data, ctx: ExecutionContext) => {
-    if (ctx.getType<ContextType | 'graphql'>() === 'graphql') {
-        // GraphQL request
-        return internal_getRequestContext(ctx.getArgByIndex(2).req, ctx);
-    } else {
-        // REST request
-        return internal_getRequestContext(ctx.switchToHttp().getRequest(), ctx);
-    }
+export const Ctx = createParamDecorator((data, executionContext: ExecutionContext) => {
+    const context = parseContext(executionContext);
+    const handlerIsFieldResolver = context.isGraphQL && isFieldResolver(context.info);
+    return internal_getRequestContext(context.req, handlerIsFieldResolver ? undefined : executionContext);
 });

+ 4 - 16
packages/core/src/api/middleware/auth-guard.ts

@@ -2,7 +2,6 @@ import { CanActivate, ExecutionContext, Injectable } from '@nestjs/common';
 import { Reflector } from '@nestjs/core';
 import { Permission } from '@vendure/common/lib/generated-types';
 import { Request, Response } from 'express';
-import { GraphQLResolveInfo } from 'graphql';
 
 import { ForbiddenError } from '../../common/error/errors';
 import { ConfigService } from '../../config/config.service';
@@ -14,6 +13,7 @@ import { ChannelService } from '../../service/services/channel.service';
 import { CustomerService } from '../../service/services/customer.service';
 import { SessionService } from '../../service/services/session.service';
 import { extractSessionToken } from '../common/extract-session-token';
+import { isFieldResolver } from '../common/is-field-resolver';
 import { parseContext } from '../common/parse-context';
 import {
     internal_getRequestContext,
@@ -47,16 +47,16 @@ export class AuthGuard implements CanActivate {
 
     async canActivate(context: ExecutionContext): Promise<boolean> {
         const { req, res, info } = parseContext(context);
-        const isFieldResolver = this.isFieldResolver(info);
+        const targetIsFieldResolver = isFieldResolver(info);
         const permissions = this.reflector.get<Permission[]>(PERMISSIONS_METADATA_KEY, context.getHandler());
-        if (isFieldResolver && !permissions) {
+        if (targetIsFieldResolver && !permissions) {
             return true;
         }
         const authDisabled = this.configService.authOptions.disableAuth;
         const isPublic = !!permissions && permissions.includes(Permission.Public);
         const hasOwnerPermission = !!permissions && permissions.includes(Permission.Owner);
         let requestContext: RequestContext;
-        if (isFieldResolver) {
+        if (targetIsFieldResolver) {
             requestContext = internal_getRequestContext(req);
         } else {
             const session = await this.getSession(req, res, hasOwnerPermission);
@@ -168,16 +168,4 @@ export class AuthGuard implements CanActivate {
         }
         return serializedSession;
     }
-
-    /**
-     * Returns true is this guard is being called on a FieldResolver, i.e. not a top-level
-     * Query or Mutation resolver.
-     */
-    private isFieldResolver(info?: GraphQLResolveInfo): boolean {
-        if (!info) {
-            return false;
-        }
-        const parentType = info?.parentType?.name;
-        return parentType !== 'Query' && parentType !== 'Mutation' && parentType !== 'Subscription';
-    }
 }

+ 6 - 5
packages/core/src/api/middleware/validate-custom-fields-interceptor.ts

@@ -1,7 +1,6 @@
 import { CallHandler, ExecutionContext, Injectable, NestInterceptor } from '@nestjs/common';
 import { ModuleRef } from '@nestjs/core';
 import { GqlExecutionContext } from '@nestjs/graphql';
-import { LanguageCode } from '@vendure/common/lib/generated-types';
 import { getGraphQlInputName } from '@vendure/common/lib/shared-utils';
 import {
     GraphQLInputType,
@@ -12,12 +11,11 @@ import {
     TypeNode,
 } from 'graphql';
 
-import { REQUEST_CONTEXT_KEY } from '../../common/constants';
 import { Injector } from '../../common/injector';
 import { ConfigService } from '../../config/config.service';
 import { CustomFieldConfig, CustomFields } from '../../config/custom-field/custom-field-types';
 import { parseContext } from '../common/parse-context';
-import { RequestContext } from '../common/request-context';
+import { internal_getRequestContext, RequestContext } from '../common/request-context';
 import { validateCustomFieldValue } from '../common/validate-custom-field-value';
 
 /**
@@ -29,7 +27,10 @@ import { validateCustomFieldValue } from '../common/validate-custom-field-value'
 export class ValidateCustomFieldsInterceptor implements NestInterceptor {
     private readonly inputsWithCustomFields: Set<string>;
 
-    constructor(private configService: ConfigService, private moduleRef: ModuleRef) {
+    constructor(
+        private configService: ConfigService,
+        private moduleRef: ModuleRef,
+    ) {
         this.inputsWithCustomFields = Object.keys(configService.customFields).reduce((inputs, entityName) => {
             inputs.add(`Create${entityName}Input`);
             inputs.add(`Update${entityName}Input`);
@@ -45,7 +46,7 @@ export class ValidateCustomFieldsInterceptor implements NestInterceptor {
             const gqlExecutionContext = GqlExecutionContext.create(context);
             const { operation, schema } = parsedContext.info;
             const variables = gqlExecutionContext.getArgs();
-            const ctx: RequestContext = (parsedContext.req as any)[REQUEST_CONTEXT_KEY];
+            const ctx = internal_getRequestContext(parsedContext.req);
 
             if (operation.operation === 'mutation') {
                 const inputTypeNames = this.getArgumentMap(operation, schema);

+ 1 - 5
packages/core/src/api/resolvers/entity/product-variant-entity.resolver.ts

@@ -175,11 +175,7 @@ export class ProductVariantAdminEntityResolver {
     }
 
     @ResolveField()
-    async stockOnHand(
-        @Ctx() ctx: RequestContext,
-        @Parent() productVariant: ProductVariant,
-        @Args() args: { options: StockMovementListOptions },
-    ): Promise<number> {
+    async stockOnHand(@Ctx() ctx: RequestContext, @Parent() productVariant: ProductVariant): Promise<number> {
         const { stockOnHand } = await this.stockLevelService.getAvailableStock(ctx, productVariant.id);
         return stockOnHand;
     }

+ 9 - 7
packages/core/src/connection/transaction-wrapper.ts

@@ -1,6 +1,6 @@
-import { from, lastValueFrom, Observable, of } from 'rxjs';
+import { from, lastValueFrom, Observable } from 'rxjs';
 import { retryWhen, take, tap } from 'rxjs/operators';
-import { Connection, EntityManager, QueryRunner } from 'typeorm';
+import { DataSource, EntityManager, QueryRunner } from 'typeorm';
 import { TransactionAlreadyStartedError } from 'typeorm/error/TransactionAlreadyStartedError';
 
 import { RequestContext } from '../api/common/request-context';
@@ -28,13 +28,13 @@ export class TransactionWrapper {
         work: (ctx: RequestContext) => Observable<T> | Promise<T>,
         mode: TransactionMode,
         isolationLevel: TransactionIsolationLevel | undefined,
-        connection: Connection,
+        connection: DataSource,
     ): Promise<T> {
         // Copy to make sure original context will remain valid after transaction completes
         const ctx = originalCtx.copy();
 
         const entityManager: EntityManager | undefined = (ctx as any)[TRANSACTION_MANAGER_KEY];
-        const queryRunner = entityManager ?.queryRunner || connection.createQueryRunner();
+        const queryRunner = entityManager?.queryRunner || connection.createQueryRunner();
 
         if (mode === 'auto') {
             await this.startTransaction(queryRunner, isolationLevel);
@@ -67,8 +67,7 @@ export class TransactionWrapper {
             }
             throw error;
         } finally {
-            if (!queryRunner.isTransactionActive
-                && queryRunner.isReleased === false) {
+            if (!queryRunner.isTransactionActive && queryRunner.isReleased === false) {
                 // There is a check for an active transaction
                 // because this could be a nested transaction (savepoint).
 
@@ -81,7 +80,10 @@ export class TransactionWrapper {
      * Attempts to start a DB transaction, with retry logic in the case that a transaction
      * is already started for the connection (which is mainly a problem with SQLite/Sql.js)
      */
-    private async startTransaction(queryRunner: QueryRunner, isolationLevel: TransactionIsolationLevel | undefined) {
+    private async startTransaction(
+        queryRunner: QueryRunner,
+        isolationLevel: TransactionIsolationLevel | undefined,
+    ) {
         const maxRetries = 25;
         let attempts = 0;
         let lastError: any;