Browse Source

feat(server): Make tax zone/price calculations configurable

Relates to #31
Michael Bromley 7 years ago
parent
commit
52ecc378e1

+ 8 - 1
server/e2e/product-category.e2e-spec.ts

@@ -14,6 +14,7 @@ import {
     LanguageCode,
     MoveProductCategory,
     ProductCategory,
+    SortOrder,
     UpdateProductCategory,
 } from '../../shared/generated-types';
 import { ROOT_CATEGORY_NAME } from '../../shared/shared-constants';
@@ -36,7 +37,13 @@ describe('ProductCategory resolver', () => {
             customerCount: 1,
         });
         await client.init();
-        const assetsResult = await client.query<GetAssetList.Query, GetAssetList.Variables>(GET_ASSET_LIST);
+        const assetsResult = await client.query<GetAssetList.Query, GetAssetList.Variables>(GET_ASSET_LIST, {
+            options: {
+                sort: {
+                    name: SortOrder.ASC,
+                },
+            },
+        });
         assets = assetsResult.assets.items;
     }, TEST_SETUP_TIMEOUT_MS);
 

+ 32 - 15
server/e2e/product.e2e-spec.ts

@@ -430,21 +430,6 @@ describe('Product resolver', () => {
         describe('variants', () => {
             let variants: ProductWithVariants.Variants[];
 
-            it('generateVariantsForProduct generates variants', async () => {
-                const result = await client.query<
-                    GenerateProductVariants.Mutation,
-                    GenerateProductVariants.Variables
-                >(GENERATE_PRODUCT_VARIANTS, {
-                    productId: newProduct.id,
-                    defaultPrice: 123,
-                    defaultSku: 'ABC',
-                });
-                variants = result.generateVariantsForProduct.variants;
-                expect(variants.length).toBe(2);
-                expect(variants[0].options.length).toBe(1);
-                expect(variants[1].options.length).toBe(1);
-            });
-
             it('generateVariantsForProduct throws with an invalid productId', async () => {
                 try {
                     await client.query<GenerateProductVariants.Mutation, GenerateProductVariants.Variables>(
@@ -461,6 +446,38 @@ describe('Product resolver', () => {
                 }
             });
 
+            it('generateVariantsForProduct throws with an invalid defaultTaxCategoryId', async () => {
+                try {
+                    await client.query<GenerateProductVariants.Mutation, GenerateProductVariants.Variables>(
+                        GENERATE_PRODUCT_VARIANTS,
+                        {
+                            productId: newProduct.id,
+                            defaultTaxCategoryId: '999',
+                        },
+                    );
+                    fail('Should have thrown');
+                } catch (err) {
+                    expect(err.message).toEqual(
+                        expect.stringContaining(`No TaxCategory with the id '999' could be found`),
+                    );
+                }
+            });
+
+            it('generateVariantsForProduct generates variants', async () => {
+                const result = await client.query<
+                    GenerateProductVariants.Mutation,
+                    GenerateProductVariants.Variables
+                >(GENERATE_PRODUCT_VARIANTS, {
+                    productId: newProduct.id,
+                    defaultPrice: 123,
+                    defaultSku: 'ABC',
+                });
+                variants = result.generateVariantsForProduct.variants;
+                expect(variants.length).toBe(2);
+                expect(variants[0].options.length).toBe(1);
+                expect(variants[1].options.length).toBe(1);
+            });
+
             it('updateProductVariants updates variants', async () => {
                 const firstVariant = variants[0];
                 const result = await client.query<

+ 0 - 6
server/src/api/common/request-context.ts

@@ -58,12 +58,6 @@ export class RequestContext {
         }
     }
 
-    get activeTaxZone(): Zone {
-        // TODO: This will vary depending on Customer data available -
-        // a customer with a billing address in another zone will alter the value etc.
-        return this.channel.defaultTaxZone;
-    }
-
     /**
      * True if the current session is authorized to access the current resolver method.
      */

+ 1 - 1
server/src/api/resolvers/zone.resolver.ts

@@ -21,7 +21,7 @@ export class ZoneResolver {
 
     @Query()
     @Allow(Permission.ReadSettings)
-    zones(@Ctx() ctx: RequestContext): Promise<Zone[]> {
+    zones(@Ctx() ctx: RequestContext): Zone[] {
         return this.zoneService.findAll(ctx);
     }
 

+ 1 - 0
server/src/config/config.service.mock.ts

@@ -27,6 +27,7 @@ export class MockConfigService implements MockClass<ConfigService> {
         promotionActions: [],
     };
     paymentOptions: {};
+    taxOptions: {};
     emailOptions: {};
     importExportOptions: {};
     orderMergeOptions = {};

+ 5 - 0
server/src/config/config.service.ts

@@ -19,6 +19,7 @@ import {
     PaymentOptions,
     PromotionOptions,
     ShippingOptions,
+    TaxOptions,
     VendureConfig,
 } from './vendure-config';
 import { VendurePlugin } from './vendure-plugin/vendure-plugin';
@@ -97,6 +98,10 @@ export class ConfigService implements VendureConfig {
         return this.activeConfig.paymentOptions;
     }
 
+    get taxOptions(): TaxOptions {
+        return this.activeConfig.taxOptions;
+    }
+
     get emailOptions(): Required<EmailOptions<any>> {
         return this.activeConfig.emailOptions as Required<EmailOptions<any>>;
     }

+ 6 - 0
server/src/config/default-config.ts

@@ -14,6 +14,8 @@ import { defaultPromotionActions } from './promotion/default-promotion-actions';
 import { defaultPromotionConditions } from './promotion/default-promotion-conditions';
 import { defaultShippingCalculator } from './shipping-method/default-shipping-calculator';
 import { defaultShippingEligibilityChecker } from './shipping-method/default-shipping-eligibility-checker';
+import { DefaultTaxCalculationStrategy } from './tax/default-tax-calculation-strategy';
+import { DefaultTaxZoneStrategy } from './tax/default-tax-zone-strategy';
 import { VendureConfig } from './vendure-config';
 
 /**
@@ -66,6 +68,10 @@ export const defaultConfig: ReadOnlyRequired<VendureConfig> = {
     paymentOptions: {
         paymentMethodHandlers: [],
     },
+    taxOptions: {
+        taxZoneStrategy: new DefaultTaxZoneStrategy(),
+        taxCalculationStrategy: new DefaultTaxCalculationStrategy(),
+    },
     emailOptions: {
         emailTemplatePath: __dirname,
         emailTypes: {},

+ 48 - 0
server/src/config/tax/default-tax-calculation-strategy.ts

@@ -0,0 +1,48 @@
+import { RequestContext } from '../../api/common/request-context';
+import { idsAreEqual } from '../../common/utils';
+import { TaxCategory } from '../../entity';
+import { TaxCalculationResult } from '../../service/helpers/tax-calculator/tax-calculator';
+import { TaxRateService } from '../../service/services/tax-rate.service';
+
+import { TaxCalculationArgs, TaxCalculationStrategy } from './tax-calculation-strategy';
+
+export class DefaultTaxCalculationStrategy implements TaxCalculationStrategy {
+    calculate(args: TaxCalculationArgs): TaxCalculationResult {
+        const { inputPrice, activeTaxZone, ctx, taxCategory, taxRateService } = args;
+        let price = 0;
+        let priceWithTax = 0;
+        let priceWithoutTax = 0;
+        let priceIncludesTax = false;
+        const taxRate = taxRateService.getApplicableTaxRate(activeTaxZone, taxCategory);
+
+        if (ctx.channel.pricesIncludeTax) {
+            const isDefaultZone = idsAreEqual(activeTaxZone.id, ctx.channel.defaultTaxZone.id);
+            const taxRateForDefaultZone = taxRateService.getApplicableTaxRate(
+                ctx.channel.defaultTaxZone,
+                taxCategory,
+            );
+            priceWithoutTax = taxRateForDefaultZone.netPriceOf(inputPrice);
+
+            if (isDefaultZone) {
+                priceIncludesTax = true;
+                price = inputPrice;
+                priceWithTax = inputPrice;
+            } else {
+                price = priceWithoutTax;
+                priceWithTax = taxRate.grossPriceOf(priceWithoutTax);
+            }
+        } else {
+            const netPrice = inputPrice;
+            price = netPrice;
+            priceWithTax = netPrice + taxRate.taxPayableOn(netPrice);
+            priceWithoutTax = netPrice;
+        }
+
+        return {
+            price,
+            priceIncludesTax,
+            priceWithTax,
+            priceWithoutTax,
+        };
+    }
+}

+ 9 - 0
server/src/config/tax/default-tax-zone-strategy.ts

@@ -0,0 +1,9 @@
+import { Channel, Order, Zone } from '../../entity';
+
+import { TaxZoneStrategy } from './tax-zone-strategy';
+
+export class DefaultTaxZoneStrategy implements TaxZoneStrategy {
+    determineTaxZone(zones: Zone[], channel: Channel, order?: Order): Zone {
+        return channel.defaultTaxZone;
+    }
+}

+ 19 - 0
server/src/config/tax/tax-calculation-strategy.ts

@@ -0,0 +1,19 @@
+import { RequestContext } from '../../api/common/request-context';
+import { TaxCategory, Zone } from '../../entity';
+import { TaxCalculationResult } from '../../service/helpers/tax-calculator/tax-calculator';
+import { TaxRateService } from '../../service/services/tax-rate.service';
+
+export interface TaxCalculationArgs {
+    inputPrice: number;
+    taxCategory: TaxCategory;
+    activeTaxZone: Zone;
+    ctx: RequestContext;
+    taxRateService: TaxRateService;
+}
+
+/**
+ * Defines how taxes are calculated based on the input price, tax zone and current request context.
+ */
+export interface TaxCalculationStrategy {
+    calculate(args: TaxCalculationArgs): TaxCalculationResult;
+}

+ 8 - 0
server/src/config/tax/tax-zone-strategy.ts

@@ -0,0 +1,8 @@
+import { Channel, Order, Zone } from '../../entity';
+
+/**
+ * Defines how the active Zone is determined for the purposes of calculating taxes.
+ */
+export interface TaxZoneStrategy {
+    determineTaxZone(zones: Zone[], channel: Channel, order?: Order): Zone;
+}

+ 17 - 0
server/src/config/vendure-config.ts

@@ -21,6 +21,8 @@ import { PromotionAction } from './promotion/promotion-action';
 import { PromotionCondition } from './promotion/promotion-condition';
 import { ShippingCalculator } from './shipping-method/shipping-calculator';
 import { ShippingEligibilityChecker } from './shipping-method/shipping-eligibility-checker';
+import { TaxCalculationStrategy } from './tax/tax-calculation-strategy';
+import { TaxZoneStrategy } from './tax/tax-zone-strategy';
 import { VendurePlugin } from './vendure-plugin/vendure-plugin';
 
 export interface AuthOptions {
@@ -183,6 +185,17 @@ export interface PaymentOptions {
     paymentMethodHandlers: Array<PaymentMethodHandler<any>>;
 }
 
+export interface TaxOptions {
+    /**
+     * Defines the strategy used to determine the applicable Zone used in tax calculations.
+     */
+    taxZoneStrategy: TaxZoneStrategy;
+    /**
+     * Defines the strategy used for calculating taxes.
+     */
+    taxCalculationStrategy: TaxCalculationStrategy;
+}
+
 export interface ImportExportOptions {
     /**
      * The directory in which assets to be imported are located.
@@ -270,6 +283,10 @@ export interface VendureConfig {
      * Configures available payment processing methods.
      */
     paymentOptions: PaymentOptions;
+    /**
+     * Configures how taxes are calculated on products.
+     */
+    taxOptions?: TaxOptions;
     /**
      * Configures the handling of transactional emails.
      */

+ 16 - 5
server/src/service/helpers/order-calculator/order-calculator.spec.ts

@@ -2,11 +2,16 @@ import { Test } from '@nestjs/testing';
 import { Connection } from 'typeorm';
 
 import { Omit } from '../../../../../shared/omit';
+import { ConfigService } from '../../../config/config.service';
+import { MockConfigService } from '../../../config/config.service.mock';
+import { DefaultTaxCalculationStrategy } from '../../../config/tax/default-tax-calculation-strategy';
+import { DefaultTaxZoneStrategy } from '../../../config/tax/default-tax-zone-strategy';
 import { OrderItem } from '../../../entity/order-item/order-item.entity';
 import { OrderLine } from '../../../entity/order-line/order-line.entity';
 import { Order } from '../../../entity/order/order.entity';
 import { TaxCategory } from '../../../entity/tax-category/tax-category.entity';
 import { TaxRateService } from '../../services/tax-rate.service';
+import { ZoneService } from '../../services/zone.service';
 import { ListQueryBuilder } from '../list-query-builder/list-query-builder';
 import { ShippingCalculator } from '../shipping-calculator/shipping-calculator';
 import { TaxCalculator } from '../tax-calculator/tax-calculator';
@@ -14,7 +19,6 @@ import {
     createRequestContext,
     MockConnection,
     taxCategoryStandard,
-    zoneDefault,
 } from '../tax-calculator/tax-calculator-test-fixtures';
 
 import { OrderCalculator } from './order-calculator';
@@ -31,10 +35,17 @@ describe('OrderCalculator', () => {
                 { provide: ShippingCalculator, useValue: { getEligibleShippingMethods: () => [] } },
                 { provide: Connection, useClass: MockConnection },
                 { provide: ListQueryBuilder, useValue: {} },
+                { provide: ConfigService, useClass: MockConfigService },
+                { provide: ZoneService, useValue: { findAll: () => [] } },
             ],
         }).compile();
 
         orderCalculator = module.get(OrderCalculator);
+        const mockConfigService = module.get<ConfigService, MockConfigService>(ConfigService);
+        mockConfigService.taxOptions = {
+            taxZoneStrategy: new DefaultTaxZoneStrategy(),
+            taxCalculationStrategy: new DefaultTaxCalculationStrategy(),
+        };
         const taxRateService = module.get(TaxRateService);
         await taxRateService.initTaxRates();
     });
@@ -64,7 +75,7 @@ describe('OrderCalculator', () => {
 
     describe('taxes only', () => {
         it('single line with taxes not included', async () => {
-            const ctx = createRequestContext(false, zoneDefault);
+            const ctx = createRequestContext(false);
             const order = createOrder({
                 lines: [{ unitPrice: 123, taxCategory: taxCategoryStandard, quantity: 1 }],
             });
@@ -75,7 +86,7 @@ describe('OrderCalculator', () => {
         });
 
         it('single line with taxes not included, multiple items', async () => {
-            const ctx = createRequestContext(false, zoneDefault);
+            const ctx = createRequestContext(false);
             const order = createOrder({
                 lines: [{ unitPrice: 123, taxCategory: taxCategoryStandard, quantity: 3 }],
             });
@@ -86,7 +97,7 @@ describe('OrderCalculator', () => {
         });
 
         it('single line with taxes included', async () => {
-            const ctx = createRequestContext(true, zoneDefault);
+            const ctx = createRequestContext(true);
             const order = createOrder({
                 lines: [{ unitPrice: 123, taxCategory: taxCategoryStandard, quantity: 1 }],
             });
@@ -97,7 +108,7 @@ describe('OrderCalculator', () => {
         });
 
         it('resets totals when lines array is empty', async () => {
-            const ctx = createRequestContext(true, zoneDefault);
+            const ctx = createRequestContext(true);
             const order = createOrder({
                 lines: [],
                 subTotal: 148,

+ 10 - 3
server/src/service/helpers/order-calculator/order-calculator.ts

@@ -4,17 +4,21 @@ import { AdjustmentType } from '../../../../../shared/generated-types';
 import { ID } from '../../../../../shared/shared-types';
 import { RequestContext } from '../../../api/common/request-context';
 import { idsAreEqual } from '../../../common/utils';
+import { ConfigService } from '../../../config/config.service';
 import { Order } from '../../../entity/order/order.entity';
 import { Promotion } from '../../../entity/promotion/promotion.entity';
 import { ShippingMethod } from '../../../entity/shipping-method/shipping-method.entity';
 import { Zone } from '../../../entity/zone/zone.entity';
 import { TaxRateService } from '../../services/tax-rate.service';
+import { ZoneService } from '../../services/zone.service';
 import { ShippingCalculator } from '../shipping-calculator/shipping-calculator';
 import { TaxCalculator } from '../tax-calculator/tax-calculator';
 
 @Injectable()
 export class OrderCalculator {
     constructor(
+        private configService: ConfigService,
+        private zoneService: ZoneService,
         private taxRateService: TaxRateService,
         private taxCalculator: TaxCalculator,
         private shippingCalculator: ShippingCalculator,
@@ -24,16 +28,18 @@ export class OrderCalculator {
      * Applies taxes and promotions to an Order. Mutates the order object.
      */
     async applyPriceAdjustments(ctx: RequestContext, order: Order, promotions: Promotion[]): Promise<Order> {
-        const activeZone = ctx.channel.defaultTaxZone;
+        const { taxZoneStrategy } = this.configService.taxOptions;
+        const zones = this.zoneService.findAll(ctx);
+        const activeTaxZone = taxZoneStrategy.determineTaxZone(zones, ctx.channel, order);
         order.clearAdjustments();
         if (order.lines.length) {
             // First apply taxes to the non-discounted prices
-            this.applyTaxes(ctx, order, activeZone);
+            this.applyTaxes(ctx, order, activeTaxZone);
             // Then test and apply promotions
             this.applyPromotions(order, promotions);
             // Finally, re-calculate taxes because the promotions may have
             // altered the unit prices, which in turn will alter the tax payable.
-            this.applyTaxes(ctx, order, activeZone);
+            this.applyTaxes(ctx, order, activeTaxZone);
             await this.applyShipping(ctx, order);
         } else {
             this.calculateOrderTotals(order);
@@ -52,6 +58,7 @@ export class OrderCalculator {
             const { price, priceIncludesTax, priceWithTax, priceWithoutTax } = this.taxCalculator.calculate(
                 line.unitPrice,
                 line.taxCategory,
+                activeZone,
                 ctx,
             );
 

+ 1 - 5
server/src/service/helpers/tax-calculator/tax-calculator-test-fixtures.ts

@@ -70,7 +70,7 @@ export class MockConnection {
     }
 }
 
-export function createRequestContext(pricesIncludeTax: boolean, activeTaxZone: Zone): RequestContext {
+export function createRequestContext(pricesIncludeTax: boolean): RequestContext {
     const channel = new Channel({
         defaultTaxZone: zoneDefault,
         pricesIncludeTax,
@@ -82,9 +82,5 @@ export function createRequestContext(pricesIncludeTax: boolean, activeTaxZone: Z
         isAuthorized: true,
         session: {} as any,
     });
-    // TODO: Hack until we implement the other ways of
-    // calculating the activeTaxZone (customer billing address etc)
-    delete Object.getPrototypeOf(ctx).activeTaxZone;
-    (ctx as any).activeTaxZone = activeTaxZone;
     return ctx;
 }

+ 29 - 20
server/src/service/helpers/tax-calculator/tax-calculator.spec.ts

@@ -1,6 +1,10 @@
 import { Test } from '@nestjs/testing';
 import { Connection } from 'typeorm';
 
+import { ConfigService } from '../../../config/config.service';
+import { MockConfigService } from '../../../config/config.service.mock';
+import { MergeOrdersStrategy } from '../../../config/order-merge-strategy/merge-orders-strategy';
+import { DefaultTaxCalculationStrategy } from '../../../config/tax/default-tax-calculation-strategy';
 import { TaxRateService } from '../../services/tax-rate.service';
 import { ListQueryBuilder } from '../list-query-builder/list-query-builder';
 
@@ -28,6 +32,7 @@ describe('TaxCalculator', () => {
             providers: [
                 TaxCalculator,
                 TaxRateService,
+                { provide: ConfigService, useClass: MockConfigService },
                 { provide: Connection, useClass: MockConnection },
                 { provide: ListQueryBuilder, useValue: {} },
             ],
@@ -35,13 +40,17 @@ describe('TaxCalculator', () => {
 
         taxCalculator = module.get(TaxCalculator);
         const taxRateService = module.get(TaxRateService);
+        const mockConfigService = module.get<ConfigService, MockConfigService>(ConfigService);
+        mockConfigService.taxOptions = {
+            taxCalculationStrategy: new DefaultTaxCalculationStrategy(),
+        };
         await taxRateService.initTaxRates();
     });
 
     describe('with prices which do not include tax', () => {
         it('standard tax, default zone', () => {
-            const ctx = createRequestContext(false, zoneDefault);
-            const result = taxCalculator.calculate(inputPrice, taxCategoryStandard, ctx);
+            const ctx = createRequestContext(false);
+            const result = taxCalculator.calculate(inputPrice, taxCategoryStandard, zoneDefault, ctx);
 
             expect(result).toEqual({
                 price: inputPrice,
@@ -52,8 +61,8 @@ describe('TaxCalculator', () => {
         });
 
         it('reduced tax, default zone', () => {
-            const ctx = createRequestContext(false, zoneDefault);
-            const result = taxCalculator.calculate(6543, taxCategoryReduced, ctx);
+            const ctx = createRequestContext(false);
+            const result = taxCalculator.calculate(6543, taxCategoryReduced, zoneDefault, ctx);
 
             expect(result).toEqual({
                 price: inputPrice,
@@ -64,8 +73,8 @@ describe('TaxCalculator', () => {
         });
 
         it('standard tax, other zone', () => {
-            const ctx = createRequestContext(false, zoneOther);
-            const result = taxCalculator.calculate(6543, taxCategoryStandard, ctx);
+            const ctx = createRequestContext(false);
+            const result = taxCalculator.calculate(6543, taxCategoryStandard, zoneOther, ctx);
 
             expect(result).toEqual({
                 price: inputPrice,
@@ -76,8 +85,8 @@ describe('TaxCalculator', () => {
         });
 
         it('reduced tax, other zone', () => {
-            const ctx = createRequestContext(false, zoneOther);
-            const result = taxCalculator.calculate(inputPrice, taxCategoryReduced, ctx);
+            const ctx = createRequestContext(false);
+            const result = taxCalculator.calculate(inputPrice, taxCategoryReduced, zoneOther, ctx);
 
             expect(result).toEqual({
                 price: inputPrice,
@@ -88,8 +97,8 @@ describe('TaxCalculator', () => {
         });
 
         it('standard tax, unconfigured zone', () => {
-            const ctx = createRequestContext(false, zoneWithNoTaxRate);
-            const result = taxCalculator.calculate(inputPrice, taxCategoryReduced, ctx);
+            const ctx = createRequestContext(false);
+            const result = taxCalculator.calculate(inputPrice, taxCategoryReduced, zoneWithNoTaxRate, ctx);
 
             expect(result).toEqual({
                 price: inputPrice,
@@ -102,8 +111,8 @@ describe('TaxCalculator', () => {
 
     describe('with prices which include tax', () => {
         it('standard tax, default zone', () => {
-            const ctx = createRequestContext(true, zoneDefault);
-            const result = taxCalculator.calculate(inputPrice, taxCategoryStandard, ctx);
+            const ctx = createRequestContext(true);
+            const result = taxCalculator.calculate(inputPrice, taxCategoryStandard, zoneDefault, ctx);
 
             expect(result).toEqual({
                 price: inputPrice,
@@ -114,8 +123,8 @@ describe('TaxCalculator', () => {
         });
 
         it('reduced tax, default zone', () => {
-            const ctx = createRequestContext(true, zoneDefault);
-            const result = taxCalculator.calculate(inputPrice, taxCategoryReduced, ctx);
+            const ctx = createRequestContext(true);
+            const result = taxCalculator.calculate(inputPrice, taxCategoryReduced, zoneDefault, ctx);
 
             expect(result).toEqual({
                 price: inputPrice,
@@ -126,8 +135,8 @@ describe('TaxCalculator', () => {
         });
 
         it('standard tax, other zone', () => {
-            const ctx = createRequestContext(true, zoneOther);
-            const result = taxCalculator.calculate(inputPrice, taxCategoryStandard, ctx);
+            const ctx = createRequestContext(true);
+            const result = taxCalculator.calculate(inputPrice, taxCategoryStandard, zoneOther, ctx);
 
             expect(result).toEqual({
                 price: taxRateDefaultStandard.netPriceOf(inputPrice),
@@ -140,8 +149,8 @@ describe('TaxCalculator', () => {
         });
 
         it('reduced tax, other zone', () => {
-            const ctx = createRequestContext(true, zoneOther);
-            const result = taxCalculator.calculate(inputPrice, taxCategoryReduced, ctx);
+            const ctx = createRequestContext(true);
+            const result = taxCalculator.calculate(inputPrice, taxCategoryReduced, zoneOther, ctx);
 
             expect(result).toEqual({
                 price: taxRateDefaultReduced.netPriceOf(inputPrice),
@@ -152,8 +161,8 @@ describe('TaxCalculator', () => {
         });
 
         it('standard tax, unconfigured zone', () => {
-            const ctx = createRequestContext(true, zoneWithNoTaxRate);
-            const result = taxCalculator.calculate(inputPrice, taxCategoryStandard, ctx);
+            const ctx = createRequestContext(true);
+            const result = taxCalculator.calculate(inputPrice, taxCategoryStandard, zoneWithNoTaxRate, ctx);
 
             expect(result).toEqual({
                 price: taxRateDefaultStandard.netPriceOf(inputPrice),

+ 16 - 37
server/src/service/helpers/tax-calculator/tax-calculator.ts

@@ -7,6 +7,7 @@ import { TaxCategory } from '../../../entity/tax-category/tax-category.entity';
 import { TaxRate } from '../../../entity/tax-rate/tax-rate.entity';
 import { Zone } from '../../../entity/zone/zone.entity';
 
+import { ConfigService } from '../../../config/config.service';
 import { TaxRateService } from '../../services/tax-rate.service';
 
 export interface TaxCalculationResult {
@@ -18,47 +19,25 @@ export interface TaxCalculationResult {
 
 @Injectable()
 export class TaxCalculator {
-    constructor(private taxRateService: TaxRateService) {}
+    constructor(private configService: ConfigService, private taxRateService: TaxRateService) {}
 
     /**
      * Given a price and TacxCategory, this method calculates the applicable tax rate and returns the adjusted
      * price along with other contextual information.
      */
-    calculate(inputPrice: number, taxCategory: TaxCategory, ctx: RequestContext): TaxCalculationResult {
-        let price = 0;
-        let priceWithTax = 0;
-        let priceWithoutTax = 0;
-        let priceIncludesTax = false;
-        const taxRate = this.taxRateService.getApplicableTaxRate(ctx.activeTaxZone, taxCategory);
-
-        if (ctx.channel.pricesIncludeTax) {
-            const isDefaultZone = idsAreEqual(ctx.activeTaxZone.id, ctx.channel.defaultTaxZone.id);
-            const taxRateForDefaultZone = this.taxRateService.getApplicableTaxRate(
-                ctx.channel.defaultTaxZone,
-                taxCategory,
-            );
-            priceWithoutTax = taxRateForDefaultZone.netPriceOf(inputPrice);
-
-            if (isDefaultZone) {
-                priceIncludesTax = true;
-                price = inputPrice;
-                priceWithTax = inputPrice;
-            } else {
-                price = priceWithoutTax;
-                priceWithTax = taxRate.grossPriceOf(priceWithoutTax);
-            }
-        } else {
-            const netPrice = inputPrice;
-            price = netPrice;
-            priceWithTax = netPrice + taxRate.taxPayableOn(netPrice);
-            priceWithoutTax = netPrice;
-        }
-
-        return {
-            price,
-            priceIncludesTax,
-            priceWithTax,
-            priceWithoutTax,
-        };
+    calculate(
+        inputPrice: number,
+        taxCategory: TaxCategory,
+        activeTaxZone: Zone,
+        ctx: RequestContext,
+    ): TaxCalculationResult {
+        const { taxCalculationStrategy } = this.configService.taxOptions;
+        return taxCalculationStrategy.calculate({
+            activeTaxZone,
+            taxRateService: this.taxRateService,
+            taxCategory,
+            ctx,
+            inputPrice,
+        });
     }
 }

+ 19 - 4
server/src/service/services/product-variant.service.ts

@@ -10,6 +10,8 @@ import { DEFAULT_LANGUAGE_CODE } from '../../common/constants';
 import { EntityNotFoundError, InternalServerError } from '../../common/error/errors';
 import { Translated } from '../../common/types/locale-types';
 import { assertFound, idsAreEqual } from '../../common/utils';
+import { ConfigService } from '../../config/config.service';
+import { TaxCategory } from '../../entity';
 import { FacetValue } from '../../entity/facet-value/facet-value.entity';
 import { ProductOption } from '../../entity/product-option/product-option.entity';
 import { ProductVariantTranslation } from '../../entity/product-variant/product-variant-translation.entity';
@@ -18,21 +20,25 @@ import { Product } from '../../entity/product/product.entity';
 import { AssetUpdater } from '../helpers/asset-updater/asset-updater';
 import { TaxCalculator } from '../helpers/tax-calculator/tax-calculator';
 import { TranslatableSaver } from '../helpers/translatable-saver/translatable-saver';
+import { getEntityOrThrow } from '../helpers/utils/get-entity-or-throw';
 import { translateDeep } from '../helpers/utils/translate-entity';
 
 import { FacetValueService } from './facet-value.service';
 import { TaxCategoryService } from './tax-category.service';
 import { TaxRateService } from './tax-rate.service';
+import { ZoneService } from './zone.service';
 
 @Injectable()
 export class ProductVariantService {
     constructor(
         @InjectConnection() private connection: Connection,
+        private configService: ConfigService,
         private taxCategoryService: TaxCategoryService,
         private facetValueService: FacetValueService,
         private taxRateService: TaxRateService,
         private taxCalculator: TaxCalculator,
         private assetUpdater: AssetUpdater,
+        private zoneService: ZoneService,
         private translatableSaver: TranslatableSaver,
     ) {}
 
@@ -169,8 +175,13 @@ export class ProductVariantService {
             ? generateAllCombinations(product.optionGroups.map(g => g.options))
             : [[]];
 
-        // TODO: how to handle default tax category?
-        const taxCategoryId = defaultTaxCategoryId || '1';
+        let taxCategory: TaxCategory;
+        if (defaultTaxCategoryId) {
+            taxCategory = await getEntityOrThrow(this.connection, TaxCategory, defaultTaxCategoryId);
+        } else {
+            const taxCategories = await this.taxCategoryService.findAll();
+            taxCategory = taxCategories[0];
+        }
 
         const variants: ProductVariant[] = [];
         for (const options of optionCombinations) {
@@ -179,7 +190,7 @@ export class ProductVariantService {
                 sku: defaultSku || 'sku-not-set',
                 price: defaultPrice || 0,
                 optionIds: options.map(o => o.id) as string[],
-                taxCategoryId,
+                taxCategoryId: taxCategory.id as string,
                 translations: [
                     {
                         languageCode: ctx.languageCode,
@@ -201,14 +212,18 @@ export class ProductVariantService {
         if (!channelPrice) {
             throw new InternalServerError(`error.no-price-found-for-channel`);
         }
+        const { taxZoneStrategy } = this.configService.taxOptions;
+        const zones = this.zoneService.findAll(ctx);
+        const activeTaxZone = taxZoneStrategy.determineTaxZone(zones, ctx.channel);
         const applicableTaxRate = this.taxRateService.getApplicableTaxRate(
-            ctx.activeTaxZone,
+            activeTaxZone,
             variant.taxCategory,
         );
 
         const { price, priceIncludesTax, priceWithTax, priceWithoutTax } = this.taxCalculator.calculate(
             channelPrice.price,
             variant.taxCategory,
+            activeTaxZone,
             ctx,
         );
 

+ 25 - 17
server/src/service/services/zone.service.ts

@@ -1,4 +1,4 @@
-import { Injectable } from '@nestjs/common';
+import { Injectable, OnModuleInit } from '@nestjs/common';
 import { InjectConnection } from '@nestjs/typeorm';
 import { Connection } from 'typeorm';
 
@@ -19,24 +19,22 @@ import { patchEntity } from '../helpers/utils/patch-entity';
 import { translateDeep } from '../helpers/utils/translate-entity';
 
 @Injectable()
-export class ZoneService {
+export class ZoneService implements OnModuleInit {
+    /**
+     * We cache all Zones to avoid hitting the DB many times per request.
+     */
+    private zones: Zone[] = [];
     constructor(@InjectConnection() private connection: Connection) {}
 
-    findAll(ctx: RequestContext): Promise<Zone[]> {
-        return this.connection
-            .getRepository(Zone)
-            .find({
-                relations: ['members'],
-            })
-            .then(zones => {
-                zones.forEach(
-                    zone =>
-                        (zone.members = zone.members.map(country =>
-                            translateDeep(country, ctx.languageCode),
-                        )),
-                );
-                return zones;
-            });
+    onModuleInit() {
+        return this.updateZonesCache();
+    }
+
+    findAll(ctx: RequestContext): Zone[] {
+        return this.zones.map(zone => {
+            zone.members = zone.members.map(country => translateDeep(country, ctx.languageCode));
+            return zone;
+        });
     }
 
     findOne(ctx: RequestContext, zoneId: ID): Promise<Zone | undefined> {
@@ -59,6 +57,7 @@ export class ZoneService {
             zone.members = await this.getCountriesFromIds(input.memberIds);
         }
         const newZone = await this.connection.getRepository(Zone).save(zone);
+        await this.updateZonesCache();
         return assertFound(this.findOne(ctx, newZone.id));
     }
 
@@ -66,6 +65,7 @@ export class ZoneService {
         const zone = await getEntityOrThrow(this.connection, Zone, input.id);
         const updatedZone = patchEntity(zone, input);
         await this.connection.getRepository(Zone).save(updatedZone);
+        await this.updateZonesCache();
         return assertFound(this.findOne(ctx, zone.id));
     }
 
@@ -75,6 +75,7 @@ export class ZoneService {
         const members = unique(zone.members.concat(countries), 'id');
         zone.members = members;
         await this.connection.getRepository(Zone).save(zone);
+        await this.updateZonesCache();
         return assertFound(this.findOne(ctx, zone.id));
     }
 
@@ -85,10 +86,17 @@ export class ZoneService {
         const zone = await getEntityOrThrow(this.connection, Zone, input.zoneId, { relations: ['members'] });
         zone.members = zone.members.filter(country => !input.memberIds.includes(country.id as string));
         await this.connection.getRepository(Zone).save(zone);
+        await this.updateZonesCache();
         return assertFound(this.findOne(ctx, zone.id));
     }
 
     private getCountriesFromIds(ids: ID[]): Promise<Country[]> {
         return this.connection.getRepository(Country).findByIds(ids);
     }
+
+    private async updateZonesCache() {
+        this.zones = await this.connection.getRepository(Zone).find({
+            relations: ['members'],
+        });
+    }
 }