Browse Source

feat(core): Add filter method to EventBus (#1930)

Yannick Boetzkes 3 years ago
parent
commit
7eabaa718a

+ 125 - 2
packages/core/src/event-bus/event-bus.spec.ts

@@ -1,7 +1,5 @@
 import { QueryRunner } from 'typeorm';
 
-import { TransactionSubscriber } from '../connection/transaction-subscriber';
-
 import { EventBus } from './event-bus';
 import { VendureEvent } from './vendure-event';
 
@@ -125,6 +123,125 @@ describe('EventBus', () => {
             expect(handler2).toHaveBeenCalledTimes(3);
         });
     });
+
+    describe('filter()', () => {
+        it('single handler is called once', async () => {
+            const handler = jest.fn();
+            const event = new TestEvent('foo');
+            eventBus.filter(vendureEvent => vendureEvent instanceof TestEvent).subscribe(handler);
+
+            eventBus.publish(event);
+            await new Promise(resolve => setImmediate(resolve));
+
+            expect(handler).toHaveBeenCalledTimes(1);
+            expect(handler).toHaveBeenCalledWith(event);
+        });
+
+        it('single handler is called on multiple events', async () => {
+            const handler = jest.fn();
+            const event1 = new TestEvent('foo');
+            const event2 = new TestEvent('bar');
+            const event3 = new TestEvent('baz');
+            eventBus.filter(vendureEvent => vendureEvent instanceof TestEvent).subscribe(handler);
+
+            eventBus.publish(event1);
+            eventBus.publish(event2);
+            eventBus.publish(event3);
+            await new Promise(resolve => setImmediate(resolve));
+
+            expect(handler).toHaveBeenCalledTimes(3);
+            expect(handler).toHaveBeenCalledWith(event1);
+            expect(handler).toHaveBeenCalledWith(event2);
+            expect(handler).toHaveBeenCalledWith(event3);
+        });
+
+        it('multiple handlers are called', async () => {
+            const handler1 = jest.fn();
+            const handler2 = jest.fn();
+            const handler3 = jest.fn();
+            const event = new TestEvent('foo');
+            eventBus.filter(vendureEvent => vendureEvent instanceof TestEvent).subscribe(handler1);
+            eventBus.filter(vendureEvent => vendureEvent instanceof TestEvent).subscribe(handler2);
+            eventBus.filter(vendureEvent => vendureEvent instanceof TestEvent).subscribe(handler3);
+
+            eventBus.publish(event);
+            await new Promise(resolve => setImmediate(resolve));
+
+            expect(handler1).toHaveBeenCalledWith(event);
+            expect(handler2).toHaveBeenCalledWith(event);
+            expect(handler3).toHaveBeenCalledWith(event);
+        });
+
+        it('handler is not called for other events', async () => {
+            const handler = jest.fn();
+            const event = new OtherTestEvent('foo');
+            eventBus.filter(vendureEvent => vendureEvent instanceof TestEvent).subscribe(handler);
+
+            eventBus.publish(event);
+            await new Promise(resolve => setImmediate(resolve));
+
+            expect(handler).not.toHaveBeenCalled();
+        });
+
+        it('handler is called for instance of child classes', async () => {
+            const handler = jest.fn();
+            const event = new ChildTestEvent('bar', 'foo');
+            eventBus.filter(vendureEvent => vendureEvent instanceof TestEvent).subscribe(handler);
+
+            eventBus.publish(event);
+            await new Promise(resolve => setImmediate(resolve));
+
+            expect(handler).toHaveBeenCalled();
+        });
+
+        it('filter() returns a subscription', async () => {
+            const handler = jest.fn();
+            const event = new TestEvent('foo');
+            const subscription = eventBus
+                .filter(vendureEvent => vendureEvent instanceof TestEvent)
+                .subscribe(handler);
+
+            eventBus.publish(event);
+            await new Promise(resolve => setImmediate(resolve));
+
+            expect(handler).toHaveBeenCalledTimes(1);
+
+            subscription.unsubscribe();
+
+            eventBus.publish(event);
+            eventBus.publish(event);
+            await new Promise(resolve => setImmediate(resolve));
+
+            expect(handler).toHaveBeenCalledTimes(1);
+        });
+
+        it('unsubscribe() only unsubscribes own handler', async () => {
+            const handler1 = jest.fn();
+            const handler2 = jest.fn();
+            const event = new TestEvent('foo');
+            const subscription1 = eventBus
+                .filter(vendureEvent => vendureEvent instanceof TestEvent)
+                .subscribe(handler1);
+            const subscription2 = eventBus
+                .filter(vendureEvent => vendureEvent instanceof TestEvent)
+                .subscribe(handler2);
+
+            eventBus.publish(event);
+            await new Promise(resolve => setImmediate(resolve));
+
+            expect(handler1).toHaveBeenCalledTimes(1);
+            expect(handler2).toHaveBeenCalledTimes(1);
+
+            subscription1.unsubscribe();
+
+            eventBus.publish(event);
+            eventBus.publish(event);
+            await new Promise(resolve => setImmediate(resolve));
+
+            expect(handler1).toHaveBeenCalledTimes(1);
+            expect(handler2).toHaveBeenCalledTimes(3);
+        });
+    });
 });
 
 class TestEvent extends VendureEvent {
@@ -133,6 +250,12 @@ class TestEvent extends VendureEvent {
     }
 }
 
+class ChildTestEvent extends TestEvent {
+    constructor(public childPayload: string, payload: string) {
+        super(payload);
+    }
+}
+
 class OtherTestEvent extends VendureEvent {
     constructor(public payload: string) {
         super();

+ 23 - 5
packages/core/src/event-bus/event-bus.ts

@@ -3,8 +3,8 @@ import { Type } from '@vendure/common/lib/shared-types';
 import { Observable, Subject } from 'rxjs';
 import { filter, mergeMap, takeUntil } from 'rxjs/operators';
 import { EntityManager } from 'typeorm';
-import { notNullOrUndefined } from '../../../common/lib/shared-utils';
 
+import { notNullOrUndefined } from '../../../common/lib/shared-utils';
 import { RequestContext } from '../api/common/request-context';
 import { TRANSACTION_MANAGER_KEY } from '../common/constants';
 import { TransactionSubscriber, TransactionSubscriberError } from '../connection/transaction-subscriber';
@@ -81,9 +81,27 @@ export class EventBus implements OnModuleDestroy {
     ofType<T extends VendureEvent>(type: Type<T>): Observable<T> {
         return this.eventStream.asObservable().pipe(
             takeUntil(this.destroy$),
-            filter(e => (e as any).constructor === type),
+            filter(e => e.constructor === type),
             mergeMap(event => this.awaitActiveTransactions(event)),
-            filter(notNullOrUndefined)
+            filter(notNullOrUndefined),
+        ) as Observable<T>;
+    }
+
+    /**
+     * @description
+     * Returns an RxJS Observable stream of events filtered by a custom predicate.
+     * If the event contains a {@link RequestContext} object, the subscriber
+     * will only get called after any active database transactions are complete.
+     *
+     * This means that the subscriber function can safely access all updated
+     * data related to the event.
+     */
+    filter<T extends VendureEvent>(predicate: (event: VendureEvent) => boolean): Observable<T> {
+        return this.eventStream.asObservable().pipe(
+            takeUntil(this.destroy$),
+            filter(e => predicate(e)),
+            mergeMap(event => this.awaitActiveTransactions(event)),
+            filter(notNullOrUndefined),
         ) as Observable<T>;
     }
 
@@ -119,7 +137,7 @@ export class EventBus implements OnModuleDestroy {
         }
 
         const [key, ctx]: [string, RequestContext] = entry;
-        
+
         const transactionManager: EntityManager | undefined = (ctx as any)[TRANSACTION_MANAGER_KEY];
         if (!transactionManager?.queryRunner) {
             return event;
@@ -134,7 +152,7 @@ export class EventBus implements OnModuleDestroy {
             delete (newContext as any)[TRANSACTION_MANAGER_KEY];
 
             // Reassign new context
-            (event as any)[key] = newContext
+            (event as any)[key] = newContext;
 
             return event;
         } catch (e: any) {