Bläddra i källkod

feat(core): Create Transaction decorator

Michael Bromley 5 år sedan
förälder
incheckning
4040089a61

+ 14 - 13
packages/core/e2e/database-transactions.e2e-spec.ts

@@ -2,12 +2,14 @@ import { Injectable } from '@nestjs/common';
 import { Args, Mutation, Query, Resolver } from '@nestjs/graphql';
 import {
     Administrator,
+    Ctx,
     InternalServerError,
     mergeConfig,
     NativeAuthenticationMethod,
     PluginCommonModule,
+    RequestContext,
+    Transaction,
     TransactionalConnection,
-    UnitOfWork,
     User,
     VendurePlugin,
 } from '@vendure/core';
@@ -22,8 +24,8 @@ import { TEST_SETUP_TIMEOUT_MS, testConfig } from '../../../e2e-common/test-conf
 class TestUserService {
     constructor(private connection: TransactionalConnection) {}
 
-    async createUser(identifier: string) {
-        const authMethod = await this.connection.getRepository(NativeAuthenticationMethod).save(
+    async createUser(ctx: RequestContext, identifier: string) {
+        const authMethod = await this.connection.getRepository(ctx, NativeAuthenticationMethod).save(
             new NativeAuthenticationMethod({
                 identifier,
                 passwordHash: 'abc',
@@ -45,12 +47,12 @@ class TestUserService {
 class TestAdminService {
     constructor(private connection: TransactionalConnection, private userService: TestUserService) {}
 
-    async createAdministrator(emailAddress: string, fail: boolean) {
-        const user = await this.userService.createUser(emailAddress);
+    async createAdministrator(ctx: RequestContext, emailAddress: string, fail: boolean) {
+        const user = await this.userService.createUser(ctx, emailAddress);
         if (fail) {
             throw new InternalServerError('Failed!');
         }
-        const admin = await this.connection.getRepository(Administrator).save(
+        const admin = await this.connection.getRepository(ctx, Administrator).save(
             new Administrator({
                 emailAddress,
                 user,
@@ -64,19 +66,18 @@ class TestAdminService {
 
 @Resolver()
 class TestResolver {
-    constructor(private uow: UnitOfWork, private testAdminService: TestAdminService) {}
+    constructor(private testAdminService: TestAdminService, private connection: TransactionalConnection) {}
 
     @Mutation()
-    createTestAdministrator(@Args() args: any) {
-        return this.uow.withTransaction(() => {
-            return this.testAdminService.createAdministrator(args.emailAddress, args.fail);
-        });
+    @Transaction
+    createTestAdministrator(@Ctx() ctx: RequestContext, @Args() args: any) {
+        return this.testAdminService.createAdministrator(ctx, args.emailAddress, args.fail);
     }
 
     @Query()
     async verify() {
-        const admins = await this.uow.getConnection().getRepository(Administrator).find();
-        const users = await this.uow.getConnection().getRepository(User).find();
+        const admins = await this.connection.getRepository(Administrator).find();
+        const users = await this.connection.getRepository(User).find();
         return {
             admins,
             users,

+ 9 - 3
packages/core/src/api/common/parse-context.ts

@@ -3,13 +3,19 @@ import { GqlExecutionContext } from '@nestjs/graphql';
 import { Request, Response } from 'express';
 import { GraphQLResolveInfo } from 'graphql';
 
+export type RestContext = { req: Request; res: Response; isGraphQL: false; info: undefined };
+export type GraphQLContext = {
+    req: Request;
+    res: Response;
+    isGraphQL: true;
+    info: GraphQLResolveInfo;
+};
+
 /**
  * Parses in the Nest ExecutionContext of the incoming request, accounting for both
  * GraphQL & REST requests.
  */
-export function parseContext(
-    context: ExecutionContext | ArgumentsHost,
-): { req: Request; res: Response; isGraphQL: boolean; info?: GraphQLResolveInfo } {
+export function parseContext(context: ExecutionContext | ArgumentsHost): RestContext | GraphQLContext {
     const graphQlContext = GqlExecutionContext.create(context as ExecutionContext);
     const restContext = GqlExecutionContext.create(context as ExecutionContext);
     const info = graphQlContext.getInfo();

+ 1 - 3
packages/core/src/api/common/request-context.service.ts

@@ -12,8 +12,6 @@ import { ChannelService } from '../../service/services/channel.service';
 import { getApiType } from './get-api-type';
 import { RequestContext } from './request-context';
 
-export const REQUEST_CONTEXT_KEY = 'vendureRequestContext';
-
 /**
  * Creates new RequestContext instances.
  */
@@ -79,7 +77,7 @@ export class RequestContextService {
         if (!user || !channel) {
             return false;
         }
-        const permissionsOnChannel = user.channelPermissions.find((c) => idsAreEqual(c.id, channel.id));
+        const permissionsOnChannel = user.channelPermissions.find(c => idsAreEqual(c.id, channel.id));
         if (permissionsOnChannel) {
             return this.arraysIntersect(permissionsOnChannel.permissions, permissions);
         }

+ 1 - 1
packages/core/src/api/decorators/request-context.decorator.ts

@@ -1,6 +1,6 @@
 import { ContextType, createParamDecorator, ExecutionContext } from '@nestjs/common';
 
-import { REQUEST_CONTEXT_KEY } from '../common/request-context.service';
+import { REQUEST_CONTEXT_KEY } from '../../common/constants';
 
 /**
  * @description

+ 15 - 0
packages/core/src/api/decorators/transaction.decorator.ts

@@ -0,0 +1,15 @@
+import { applyDecorators, UseInterceptors } from '@nestjs/common';
+
+import { TransactionInterceptor } from '../middleware/transaction-interceptor';
+
+/**
+ * @description
+ * Runs the decorated method in a TypeORM transaction. It works by creating a transctional
+ * QueryRunner which gets attached to the RequestContext object. When the RequestContext
+ * is the passed to the {@link TransactionalConnection} `getRepository()` method, this
+ * QueryRunner is used to execute the queries within this transaction.
+ *
+ * @docsCategory request
+ * @docsPage Decorators
+ */
+export const Transaction = applyDecorators(UseInterceptors(TransactionInterceptor));

+ 1 - 0
packages/core/src/api/index.ts

@@ -1,6 +1,7 @@
 export { ApiType } from './common/get-api-type';
 export * from './common/request-context';
 export * from './decorators/allow.decorator';
+export * from './decorators/transaction.decorator';
 export * from './decorators/api.decorator';
 export * from './decorators/request-context.decorator';
 export * from './resolvers/admin/search.resolver';

+ 2 - 1
packages/core/src/api/middleware/auth-guard.ts

@@ -3,13 +3,14 @@ import { Reflector } from '@nestjs/core';
 import { Permission } from '@vendure/common/lib/generated-types';
 import { Request, Response } from 'express';
 
+import { REQUEST_CONTEXT_KEY } from '../../common/constants';
 import { ForbiddenError } from '../../common/error/errors';
 import { ConfigService } from '../../config/config.service';
 import { CachedSession } from '../../config/session-cache/session-cache-strategy';
 import { SessionService } from '../../service/services/session.service';
 import { extractSessionToken } from '../common/extract-session-token';
 import { parseContext } from '../common/parse-context';
-import { REQUEST_CONTEXT_KEY, RequestContextService } from '../common/request-context.service';
+import { RequestContextService } from '../common/request-context.service';
 import { setSessionToken } from '../common/set-session-token';
 import { PERMISSIONS_METADATA_KEY } from '../decorators/allow.decorator';
 

+ 2 - 3
packages/core/src/api/middleware/id-interceptor.ts

@@ -29,10 +29,9 @@ export class IdInterceptor implements NestInterceptor {
     constructor(private idCodecService: IdCodecService) {}
 
     intercept(context: ExecutionContext, next: CallHandler<any>): Observable<any> {
-        const { isGraphQL, req } = parseContext(context);
-        if (isGraphQL) {
+        const { isGraphQL, req, info } = parseContext(context);
+        if (isGraphQL && info) {
             const args = GqlExecutionContext.create(context).getArgs();
-            const info = GqlExecutionContext.create(context).getInfo();
             const transformer = this.getTransformerForSchema(info.schema);
             this.decodeIdArguments(transformer, info.operation, args);
         }

+ 59 - 0
packages/core/src/api/middleware/transaction-interceptor.ts

@@ -0,0 +1,59 @@
+import { CallHandler, ExecutionContext, Injectable, NestInterceptor } from '@nestjs/common';
+import { Observable, of } from 'rxjs';
+import { tap } from 'rxjs/operators';
+
+import { REQUEST_CONTEXT_KEY, TRANSACTION_MANAGER_KEY } from '../../common/constants';
+import { TransactionalConnection } from '../../service/transaction/transactional-connection';
+import { parseContext } from '../common/parse-context';
+import { RequestContext } from '../common/request-context';
+
+/**
+ * @description
+ * Used by the {@link Transaction} decorator to create a transactional query runner
+ * and attach it to the RequestContext.
+ */
+@Injectable()
+export class TransactionInterceptor implements NestInterceptor {
+    constructor(private connection: TransactionalConnection) {}
+    intercept(context: ExecutionContext, next: CallHandler): Observable<any> {
+        const { isGraphQL, req } = parseContext(context);
+        const ctx = (req as any)[REQUEST_CONTEXT_KEY];
+        if (ctx) {
+            return of(this.withTransaction(ctx, () => next.handle().toPromise()));
+        } else {
+            return next.handle();
+        }
+    }
+
+    /**
+     * @description
+     * Executes the `work` function within the context of a transaction.
+     */
+    private async withTransaction<T>(ctx: RequestContext, work: () => T): Promise<T> {
+        const queryRunnerExists = !!(ctx as any)[TRANSACTION_MANAGER_KEY];
+        if (queryRunnerExists) {
+            // If a QueryRunner already exists on the RequestContext, there must be an existing
+            // outer transaction in progress. In that case, we just execute the work function
+            // as usual without needing to further wrap in a transaction.
+            return work();
+        }
+        const queryRunner = this.connection.rawConnection.createQueryRunner();
+        await queryRunner.startTransaction();
+        (ctx as any)[TRANSACTION_MANAGER_KEY] = queryRunner.manager;
+
+        try {
+            const result = await work();
+            if (queryRunner.isTransactionActive) {
+                await queryRunner.commitTransaction();
+            }
+            return result;
+        } catch (error) {
+            if (queryRunner.isTransactionActive) {
+                await queryRunner.rollbackTransaction();
+            }
+            throw error;
+        } finally {
+            await queryRunner.release();
+        }
+    }
+}

+ 6 - 15
packages/core/src/api/middleware/validate-custom-fields-interceptor.ts

@@ -1,29 +1,20 @@
 import { CallHandler, ExecutionContext, Injectable, NestInterceptor } from '@nestjs/common';
 import { GqlExecutionContext } from '@nestjs/graphql';
 import { LanguageCode } from '@vendure/common/lib/generated-types';
-import { assertNever } from '@vendure/common/lib/shared-utils';
 import {
-    DefinitionNode,
     GraphQLInputType,
     GraphQLList,
     GraphQLNonNull,
-    GraphQLResolveInfo,
     GraphQLSchema,
     OperationDefinitionNode,
     TypeNode,
 } from 'graphql';
 
-import { UserInputError } from '../../common/error/errors';
+import { REQUEST_CONTEXT_KEY } from '../../common/constants';
 import { ConfigService } from '../../config/config.service';
-import {
-    CustomFieldConfig,
-    CustomFields,
-    LocaleStringCustomFieldConfig,
-    StringCustomFieldConfig,
-} from '../../config/custom-field/custom-field-types';
+import { CustomFieldConfig, CustomFields } from '../../config/custom-field/custom-field-types';
 import { parseContext } from '../common/parse-context';
 import { RequestContext } from '../common/request-context';
-import { REQUEST_CONTEXT_KEY } from '../common/request-context.service';
 import { validateCustomFieldValue } from '../common/validate-custom-field-value';
 
 /**
@@ -44,12 +35,12 @@ export class ValidateCustomFieldsInterceptor implements NestInterceptor {
     }
 
     intercept(context: ExecutionContext, next: CallHandler<any>) {
-        const { isGraphQL } = parseContext(context);
-        if (isGraphQL) {
+        const parsedContext = parseContext(context);
+        if (parsedContext.isGraphQL) {
             const gqlExecutionContext = GqlExecutionContext.create(context);
-            const { operation, schema } = gqlExecutionContext.getInfo<GraphQLResolveInfo>();
+            const { operation, schema } = parsedContext.info;
             const variables = gqlExecutionContext.getArgs();
-            const ctx: RequestContext = gqlExecutionContext.getContext().req[REQUEST_CONTEXT_KEY];
+            const ctx: RequestContext = (parsedContext.req as any)[REQUEST_CONTEXT_KEY];
 
             if (operation.operation === 'mutation') {
                 const inputTypeNames = this.getArgumentMap(operation, schema);

+ 2 - 0
packages/core/src/common/constants.ts

@@ -5,3 +5,5 @@ import { LanguageCode } from '@vendure/common/lib/generated-types';
  * VendureConfig to ensure at least a valid LanguageCode is available.
  */
 export const DEFAULT_LANGUAGE_CODE = LanguageCode.en;
+export const TRANSACTION_MANAGER_KEY = Symbol('TRANSACTION_MANAGER');
+export const REQUEST_CONTEXT_KEY = 'vendureRequestContext';

+ 6 - 9
packages/core/src/entity/order/order.entity.ts

@@ -3,9 +3,11 @@ import { DeepPartial, ID } from '@vendure/common/lib/shared-types';
 import { Column, Entity, JoinTable, ManyToMany, ManyToOne, OneToMany } from 'typeorm';
 
 import { Calculated } from '../../common/calculated-decorator';
+import { ChannelAware } from '../../common/types/common-types';
 import { HasCustomFields } from '../../config/custom-field/custom-field-types';
 import { OrderState } from '../../service/helpers/order-state-machine/order-state';
 import { VendureEntity } from '../base/base.entity';
+import { Channel } from '../channel/channel.entity';
 import { CustomOrderFields } from '../custom-entity-fields';
 import { Customer } from '../customer/customer.entity';
 import { EntityId } from '../entity-id.decorator';
@@ -14,8 +16,6 @@ import { OrderLine } from '../order-line/order-line.entity';
 import { Payment } from '../payment/payment.entity';
 import { Promotion } from '../promotion/promotion.entity';
 import { ShippingMethod } from '../shipping-method/shipping-method.entity';
-import { ChannelAware } from '../../common/types/common-types';
-import { Channel } from '../channel/channel.entity';
 
 /**
  * @description
@@ -96,7 +96,7 @@ export class Order extends VendureEntity implements ChannelAware, HasCustomField
     @EntityId({ nullable: true })
     taxZoneId?: ID;
 
-    @ManyToMany((type) => Channel)
+    @ManyToMany(type => Channel)
     @JoinTable()
     channels: Channel[];
 
@@ -134,11 +134,8 @@ export class Order extends VendureEntity implements ChannelAware, HasCustomField
     }
 
     getOrderItems(): OrderItem[] {
-        return this.lines.reduce(
-            (items, line) => {
-                return [...items, ...line.items];
-            },
-            [] as OrderItem[],
-        );
+        return this.lines.reduce((items, line) => {
+            return [...items, ...line.items];
+        }, [] as OrderItem[]);
     }
 }

+ 0 - 1
packages/core/src/service/index.ts

@@ -33,5 +33,4 @@ export * from './services/tax-category.service';
 export * from './services/tax-rate.service';
 export * from './services/user.service';
 export * from './services/user.service';
-export * from './transaction/unit-of-work';
 export * from './transaction/transactional-connection';

+ 42 - 0
packages/core/src/service/initializer.service.ts

@@ -0,0 +1,42 @@
+import { Injectable } from '@nestjs/common';
+
+import { AdministratorService } from './services/administrator.service';
+import { ChannelService } from './services/channel.service';
+import { GlobalSettingsService } from './services/global-settings.service';
+import { PaymentMethodService } from './services/payment-method.service';
+import { RoleService } from './services/role.service';
+import { ShippingMethodService } from './services/shipping-method.service';
+import { TaxRateService } from './services/tax-rate.service';
+
+/**
+ * Only used internally to run the various service init methods in the correct
+ * sequence on bootstrap.
+ */
+@Injectable()
+export class InitializerService {
+    constructor(
+        private channelService: ChannelService,
+        private roleService: RoleService,
+        private administratorService: AdministratorService,
+        private taxRateService: TaxRateService,
+        private shippingMethodService: ShippingMethodService,
+        private paymentMethodService: PaymentMethodService,
+        private globalSettingsService: GlobalSettingsService,
+    ) {}
+
+    async onModuleInit() {
+        // IMPORTANT - why manually invoke these init methods rather than just relying on
+        // Nest's "onModuleInit" lifecycle hook within each individual service class?
+        // The reason is that the order of invokation matters. By explicitly invoking the
+        // methods below, we can e.g. guarantee that the default channel exists
+        // (channelService.initChannels()) before we try to create any roles (which assume that
+        // there is a default Channel to work with.
+        await this.globalSettingsService.initGlobalSettings();
+        await this.channelService.initChannels();
+        await this.roleService.initRoles();
+        await this.administratorService.initAdministrators();
+        await this.taxRateService.initTaxRates();
+        await this.shippingMethodService.initShippingMethods();
+        await this.paymentMethodService.initPaymentMethods();
+    }
+}

+ 4 - 31
packages/core/src/service/service.module.ts

@@ -1,4 +1,4 @@
-import { DynamicModule, Module, OnModuleInit } from '@nestjs/common';
+import { DynamicModule, Module } from '@nestjs/common';
 import { TypeOrmModule } from '@nestjs/typeorm';
 import { ConnectionOptions } from 'typeorm';
 
@@ -25,6 +25,7 @@ import { SlugValidator } from './helpers/slug-validator/slug-validator';
 import { TaxCalculator } from './helpers/tax-calculator/tax-calculator';
 import { TranslatableSaver } from './helpers/translatable-saver/translatable-saver';
 import { VerificationTokenGenerator } from './helpers/verification-token-generator/verification-token-generator';
+import { InitializerService } from './initializer.service';
 import { AdministratorService } from './services/administrator.service';
 import { AssetService } from './services/asset.service';
 import { AuthService } from './services/auth.service';
@@ -55,7 +56,6 @@ import { TaxRateService } from './services/tax-rate.service';
 import { UserService } from './services/user.service';
 import { ZoneService } from './services/zone.service';
 import { TransactionalConnection } from './transaction/transactional-connection';
-import { UnitOfWork } from './transaction/unit-of-work';
 
 const services = [
     AdministratorService,
@@ -104,7 +104,6 @@ const helpers = [
     ShippingConfiguration,
     SlugValidator,
     ExternalAuthenticationService,
-    UnitOfWork,
     TransactionalConnection,
 ];
 
@@ -120,36 +119,10 @@ let workerTypeOrmModule: DynamicModule;
  */
 @Module({
     imports: [ConfigModule, EventBusModule, WorkerServiceModule, JobQueueModule],
-    providers: [...services, ...helpers],
+    providers: [...services, ...helpers, InitializerService],
     exports: [...services, ...helpers],
 })
-export class ServiceCoreModule implements OnModuleInit {
-    constructor(
-        private channelService: ChannelService,
-        private roleService: RoleService,
-        private administratorService: AdministratorService,
-        private taxRateService: TaxRateService,
-        private shippingMethodService: ShippingMethodService,
-        private paymentMethodService: PaymentMethodService,
-        private globalSettingsService: GlobalSettingsService,
-    ) {}
-
-    async onModuleInit() {
-        // IMPORTANT - why manually invoke these init methods rather than just relying on
-        // Nest's "onModuleInit" lifecycle hook within each individual service class?
-        // The reason is that the order of invokation matters. By explicitly invoking the
-        // methods below, we can e.g. guarantee that the default channel exists
-        // (channelService.initChannels()) before we try to create any roles (which assume that
-        // there is a default Channel to work with.
-        await this.globalSettingsService.initGlobalSettings();
-        await this.channelService.initChannels();
-        await this.roleService.initRoles();
-        await this.administratorService.initAdministrators();
-        await this.taxRateService.initTaxRates();
-        await this.shippingMethodService.initShippingMethods();
-        await this.paymentMethodService.initPaymentMethods();
-    }
-}
+export class ServiceCoreModule {}
 
 /**
  * The ServiceModule is responsible for the service layer, i.e. accessing the database

+ 3 - 5
packages/core/src/service/services/collection.service.ts

@@ -1,4 +1,4 @@
-import { OnModuleInit, Optional } from '@nestjs/common';
+import { Injectable, OnModuleInit, Optional } from '@nestjs/common';
 import { InjectConnection } from '@nestjs/typeorm';
 import {
     ConfigurableOperation,
@@ -49,15 +49,13 @@ import { AssetService } from './asset.service';
 import { ChannelService } from './channel.service';
 import { FacetValueService } from './facet-value.service';
 
+@Injectable()
 export class CollectionService implements OnModuleInit {
     private rootCollection: Collection | undefined;
     private applyFiltersQueue: JobQueue<ApplyCollectionFiltersJobData>;
 
     constructor(
-        // Optional() allows the onModuleInit() hook to run with injected
-        // providers despite the request-scoped TransactionalConnection
-        // not yet having been created
-        @Optional() private connection: TransactionalConnection,
+        private connection: TransactionalConnection,
         private channelService: ChannelService,
         private assetService: AssetService,
         private facetValueService: FacetValueService,

+ 2 - 0
packages/core/src/service/services/order.service.ts

@@ -1,3 +1,4 @@
+import { Injectable } from '@nestjs/common';
 import { PaymentInput } from '@vendure/common/lib/generated-shop-types';
 import {
     AddNoteToOrderInput,
@@ -72,6 +73,7 @@ import { ProductVariantService } from './product-variant.service';
 import { PromotionService } from './promotion.service';
 import { StockMovementService } from './stock-movement.service';
 
+@Injectable()
 export class OrderService {
     constructor(
         private connection: TransactionalConnection,

+ 4 - 0
packages/core/src/service/services/tax-rate.service.ts

@@ -1,3 +1,5 @@
+import { Injectable } from '@nestjs/common';
+import { InjectConnection } from '@nestjs/typeorm';
 import {
     CreateTaxRateInput,
     DeletionResponse,
@@ -5,6 +7,7 @@ import {
     UpdateTaxRateInput,
 } from '@vendure/common/lib/generated-types';
 import { ID, PaginatedList } from '@vendure/common/lib/shared-types';
+import { Connection } from 'typeorm';
 
 import { RequestContext } from '../../api/common/request-context';
 import { EntityNotFoundError } from '../../common/error/errors';
@@ -23,6 +26,7 @@ import { patchEntity } from '../helpers/utils/patch-entity';
 import { TransactionalConnection } from '../transaction/transactional-connection';
 import { TaxRateUpdatedMessage } from '../types/tax-rate-messages';
 
+@Injectable()
 export class TaxRateService {
     /**
      * We cache all active TaxRates to avoid hitting the DB many times

+ 29 - 17
packages/core/src/service/transaction/transactional-connection.ts

@@ -1,27 +1,26 @@
-import { Injectable, Scope } from '@nestjs/common';
-import { Connection, ConnectionOptions, EntitySchema, getRepository, ObjectType, Repository } from 'typeorm';
+import { Injectable } from '@nestjs/common';
+import { InjectConnection } from '@nestjs/typeorm';
+import { Connection, EntitySchema, getRepository, ObjectType, Repository } from 'typeorm';
 import { RepositoryFactory } from 'typeorm/repository/RepositoryFactory';
 
-import { UnitOfWork } from './unit-of-work';
+import { RequestContext } from '../../api/common/request-context';
+import { TRANSACTION_MANAGER_KEY } from '../../common/constants';
 
 /**
  * @description
  * The TransactionalConnection is a wrapper around the TypeORM `Connection` object which works in conjunction
- * with the {@link UnitOfWork} class to implement per-request transactions. All services which access the
+ * with the {@link Transaction} decorator to implement per-request transactions. All services which access the
  * database should use this class rather than the raw TypeORM connection, to ensure that db changes can be
  * easily wrapped in transactions when required.
  *
  * The service layer does not need to know about the scope of a transaction, as this is covered at the
- * API level depending on the nature of the request.
- *
- * Based on the pattern outlined in
- * [this article](https://aaronboman.com/programming/2020/05/15/per-request-database-transactions-with-nestjs-and-typeorm/)
+ * API by the use of the `Transaction` decorator.
  *
  * @docsCategory data-access
  */
 @Injectable()
 export class TransactionalConnection {
-    constructor(private uow: UnitOfWork) {}
+    constructor(@InjectConnection() private connection: Connection) {}
 
     /**
      * @description
@@ -30,7 +29,7 @@ export class TransactionalConnection {
      * transactions.
      */
     get rawConnection(): Connection {
-        return this.uow.getConnection();
+        return this.connection;
     }
 
     /**
@@ -38,13 +37,26 @@ export class TransactionalConnection {
      * Gets a repository bound to the current transaction manager
      * or defaults to the current connection's call to getRepository().
      */
-    getRepository<Entity>(target: ObjectType<Entity> | EntitySchema<Entity> | string): Repository<Entity> {
-        const transactionManager = this.uow.getTransactionManager();
-        if (transactionManager) {
-            const connection = this.uow.getConnection();
-            const metadata = connection.getMetadata(target);
-            return new RepositoryFactory().create(transactionManager, metadata);
+    getRepository<Entity>(target: ObjectType<Entity> | EntitySchema<Entity> | string): Repository<Entity>;
+    getRepository<Entity>(
+        ctx: RequestContext,
+        target: ObjectType<Entity> | EntitySchema<Entity> | string,
+    ): Repository<Entity>;
+    getRepository<Entity>(
+        ctxOrTarget: RequestContext | ObjectType<Entity> | EntitySchema<Entity> | string,
+        maybeTarget?: ObjectType<Entity> | EntitySchema<Entity> | string,
+    ): Repository<Entity> {
+        if (ctxOrTarget instanceof RequestContext) {
+            const transactionManager = (ctxOrTarget as any)[TRANSACTION_MANAGER_KEY];
+            if (transactionManager && maybeTarget) {
+                const metadata = this.connection.getMetadata(maybeTarget);
+                return new RepositoryFactory().create(transactionManager, metadata);
+            } else {
+                // tslint:disable-next-line:no-non-null-assertion
+                return getRepository(maybeTarget!);
+            }
+        } else {
+            return getRepository(ctxOrTarget);
         }
-        return getRepository(target);
     }
 }

+ 0 - 42
packages/core/src/service/transaction/unit-of-work.ts

@@ -1,42 +0,0 @@
-import { Injectable, Scope } from '@nestjs/common';
-import { InjectConnection } from '@nestjs/typeorm';
-import { Connection, EntityManager } from 'typeorm';
-
-/**
- * @description
- * This class is used to wrap an entire request in a database transaction. It should
- * generally be injected at the API layer and wrap the service-layer call(s) so that
- * all DB access within the `withTransaction()` method takes place within a transaction.
- *
- * @docsCategory data-access
- */
-@Injectable()
-export class UnitOfWork {
-    private transactionManager: EntityManager | null;
-    constructor(@InjectConnection() private connection: Connection) {}
-
-    getTransactionManager(): EntityManager | null {
-        return this.transactionManager;
-    }
-
-    getConnection(): Connection {
-        return this.connection;
-    }
-
-    async withTransaction<T>(work: () => T): Promise<T> {
-        const queryRunner = this.connection.createQueryRunner();
-        await queryRunner.startTransaction();
-        this.transactionManager = queryRunner.manager;
-        try {
-            const result = await work();
-            await queryRunner.commitTransaction();
-            return result;
-        } catch (error) {
-            await queryRunner.rollbackTransaction();
-            throw error;
-        } finally {
-            await queryRunner.release();
-            this.transactionManager = null;
-        }
-    }
-}