Using Recursion to Generate Feature Interaction Terms

Jose M Sallan 2024-05-31 8 min read

In prediction tasks, we may want to generate products of powers of a set of \(n\) variables \(x_1, \dots, x_n\), with each variable raised to a power between 0 and \(p\). These products can be used as features for regression-based prediction techniques, like neural networks.

For small values of \(p\) and \(n\), these terms are easy to list by hand. For instance, for \(n=2\) variables, the non-constant products of total degree at most \(p=2\) are:

  • \(x_1\)
  • \(x_2\)
  • \(x_1^2\)
  • \(x_2^2\)
  • \(x_1x_2\)

For larger values of \(p\) and \(n\), we need a systematic way of generating these products of powers. Here I suggest a recursion-based procedure that generates all products of powers of a set of \(n\) variables, with each exponent ranging from 0 to \(p\), and provide an implementation in base R. I then use this result to obtain the interaction terms of \(n\) features whose exponents sum to \(n\).

Recursion

As defined in Wikipedia, recursion occurs when the definition of a concept or process depends on a simpler or previous version of itself. To define a recursion we need:

  • A set of base cases that do not need recursion to produce an answer. In the Fibonacci sequence, these are fibonacci(0) = 0 and fibonacci(1) = 1.
  • A recursive step that reduces all the other cases to the base cases. In the Fibonacci sequence, the step is fibonacci(n) = fibonacci(n-1) + fibonacci(n-2).

In base R, recursion is straightforward to implement: we write a function that covers the base cases and the recursive step. Note that the function calls itself within its own definition.

fibonacci <- function(n) {
  if (n <= 1) {  # Base case: If n is 0 or 1, return n
    return(n)
  } else {
    # Recursive case: Calculate the nth term using recursion
    return(fibonacci(n - 1) + fibonacci(n - 2))
  }
}

Some values of the function:

fibonacci(0)
## [1] 0
fibonacci(10)
## [1] 55

The first 21 terms of the Fibonacci sequence:

sapply(0:20, fibonacci)
##  [1]    0    1    1    2    3    5    8   13   21   34   55   89  144  233  377
## [16]  610  987 1597 2584 4181 6765
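The naive recursion recomputes the same terms many times. As a small illustration (my addition; fib_calls() is a hypothetical helper, not part of the original code), we can count how many calls fibonacci(n) makes: each call counts itself plus all the calls of its two recursive branches, and each base case is a single call.

```r
# Hypothetical helper: number of calls made by the naive fibonacci(n)
fib_calls <- function(n) {
  if (n <= 1) {
    return(1)  # a base case is a single call
  }
  # this call plus all calls made by both recursive branches
  1 + fib_calls(n - 1) + fib_calls(n - 2)
}

fib_calls(10)
## [1] 177
fib_calls(20)
## [1] 21891
```

The number of calls grows as fast as the sequence itself, so this direct recursive version is only practical for small values of n.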

Products of Powers

We represent each product of powers as a vector of length \(n\) whose \(j\)-th entry, a value from 0 to \(p\), is the exponent of \(x_j\). A value of zero means that the variable is not included in the product. For \(n = 3\) and \(p = 2\), the product \(x_1x_3^2\) is encoded as \((1, 0, 2)\).
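To make the encoding concrete, here is a small sketch that turns an exponent vector into a readable term. term_label() is a hypothetical helper name, not part of the original code, and the x1, x2, ... labels are an assumption for display only.

```r
# Hypothetical helper: turn an exponent vector into a readable term label
term_label <- function(e) {
  idx <- which(e > 0)                  # variables present in the product
  if (length(idx) == 0) return("1")    # all exponents zero: the intercept
  parts <- ifelse(e[idx] == 1,
                  paste0("x", idx),                  # exponent 1: no "^"
                  paste0("x", idx, "^", e[idx]))     # higher exponents
  paste(parts, collapse = "*")
}

term_label(c(1, 0, 2))
## [1] "x1*x3^2"
term_label(c(0, 0, 0))
## [1] "1"
```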

In the suggested recursive process, we add one variable at each step. The base case is \(n=1\): a single column holding the exponents 0 to \(p\). For \(p=3\), this base case is:

s1 <- matrix(0:3, 4, 1)

For \(n=2\), we take the \(n=1\) matrix and append a second column, once for each exponent from 0 to 3:

s2 <- lapply(0:3, function(i) cbind(s1, i))
s2
## [[1]]
##        i
## [1,] 0 0
## [2,] 1 0
## [3,] 2 0
## [4,] 3 0
## 
## [[2]]
##        i
## [1,] 0 1
## [2,] 1 1
## [3,] 2 1
## [4,] 3 1
## 
## [[3]]
##        i
## [1,] 0 2
## [2,] 1 2
## [3,] 2 2
## [4,] 3 2
## 
## [[4]]
##        i
## [1,] 0 3
## [2,] 1 3
## [3,] 2 3
## [4,] 3 3

We can then bind these matrices by rows:

do.call(rbind, s2)
##         i
##  [1,] 0 0
##  [2,] 1 0
##  [3,] 2 0
##  [4,] 3 0
##  [5,] 0 1
##  [6,] 1 1
##  [7,] 2 1
##  [8,] 3 1
##  [9,] 0 2
## [10,] 1 2
## [11,] 2 2
## [12,] 3 2
## [13,] 0 3
## [14,] 1 3
## [15,] 2 3
## [16,] 3 3
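As a cross-check (my addition, not part of the original procedure), base R's expand.grid() should produce the same grid in the same row order, because it varies its first argument fastest, exactly as the rows above do. The object names s2_rows and g are only for this sketch.

```r
# Rebuild the n = 2 matrix from the text
s1 <- matrix(0:3, 4, 1)
s2_rows <- do.call(rbind, lapply(0:3, function(i) cbind(s1, i)))

# expand.grid() varies its first argument fastest, like the recursion
g <- as.matrix(expand.grid(0:3, 0:3))

all(s2_rows == g)
## [1] TRUE
```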

For \(n\) variables, we implement this process recursively: the matrix for \(n\) variables is built from the matrix for \(n-1\) variables in the same way.

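powers <- function(p, n){
  if(n == 1){
    # base case: a single variable with exponents 0 to p
    s0 <- matrix(0:p, p + 1, 1)
    return(s0)
  }else{
    # recursive step: compute the n - 1 matrix once,
    # then append a column with each exponent i from 0 to p
    prev <- powers(p, n - 1)
    sn <- lapply(0:p, function(i) cbind(prev, i))
    m <- do.call(rbind, sn)
    colnames(m) <- NULL
    return(m)
  }
}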

Let’s see some examples:

powers(3, 1)
##      [,1]
## [1,]    0
## [2,]    1
## [3,]    2
## [4,]    3
powers(3, 2)
##       [,1] [,2]
##  [1,]    0    0
##  [2,]    1    0
##  [3,]    2    0
##  [4,]    3    0
##  [5,]    0    1
##  [6,]    1    1
##  [7,]    2    1
##  [8,]    3    1
##  [9,]    0    2
## [10,]    1    2
## [11,]    2    2
## [12,]    3    2
## [13,]    0    3
## [14,]    1    3
## [15,]    2    3
## [16,]    3    3
powers(3, 3)
##       [,1] [,2] [,3]
##  [1,]    0    0    0
##  [2,]    1    0    0
##  [3,]    2    0    0
##  [4,]    3    0    0
##  [5,]    0    1    0
##  [6,]    1    1    0
##  [7,]    2    1    0
##  [8,]    3    1    0
##  [9,]    0    2    0
## [10,]    1    2    0
## [11,]    2    2    0
## [12,]    3    2    0
## [13,]    0    3    0
## [14,]    1    3    0
## [15,]    2    3    0
## [16,]    3    3    0
## [17,]    0    0    1
## [18,]    1    0    1
## [19,]    2    0    1
## [20,]    3    0    1
## [21,]    0    1    1
## [22,]    1    1    1
## [23,]    2    1    1
## [24,]    3    1    1
## [25,]    0    2    1
## [26,]    1    2    1
## [27,]    2    2    1
## [28,]    3    2    1
## [29,]    0    3    1
## [30,]    1    3    1
## [31,]    2    3    1
## [32,]    3    3    1
## [33,]    0    0    2
## [34,]    1    0    2
## [35,]    2    0    2
## [36,]    3    0    2
## [37,]    0    1    2
## [38,]    1    1    2
## [39,]    2    1    2
## [40,]    3    1    2
## [41,]    0    2    2
## [42,]    1    2    2
## [43,]    2    2    2
## [44,]    3    2    2
## [45,]    0    3    2
## [46,]    1    3    2
## [47,]    2    3    2
## [48,]    3    3    2
## [49,]    0    0    3
## [50,]    1    0    3
## [51,]    2    0    3
## [52,]    3    0    3
## [53,]    0    1    3
## [54,]    1    1    3
## [55,]    2    1    3
## [56,]    3    1    3
## [57,]    0    2    3
## [58,]    1    2    3
## [59,]    2    2    3
## [60,]    3    2    3
## [61,]    0    3    3
## [62,]    1    3    3
## [63,]    2    3    3
## [64,]    3    3    3

The function generates \((p+1)^n\) products of powers, since each of the \(n\) exponents can take any of the \(p+1\) values from 0 to \(p\).
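A quick numerical check of the row count, using a version of powers() equivalent to the one above (repeated so the block runs on its own): the number of rows should equal \((p+1)^n\) for every combination tried. The grid of \(p\) and \(n\) values is an arbitrary choice for this sketch.

```r
# A version of powers() equivalent to the one above, repeated so this
# block runs on its own
powers <- function(p, n){
  if(n == 1){
    return(matrix(0:p, p + 1, 1))
  }
  prev <- powers(p, n - 1)
  m <- do.call(rbind, lapply(0:p, function(i) cbind(prev, i)))
  colnames(m) <- NULL
  m
}

# each of the n exponents takes one of the p + 1 values 0, ..., p
pn <- expand.grid(p = 1:4, n = 1:4)
counts_ok <- mapply(function(p, n) nrow(powers(p, n)) == (p + 1)^n,
                    pn$p, pn$n)
all(counts_ok)
## [1] TRUE
nrow(powers(3, 3))
## [1] 64
```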

Interactions of Order \(n\)

Of all the generated products of powers, we want to keep only those whose exponents sum to \(n\). For \(p = 2\) and \(n = 2\), these are:

  • \(x_1^2\)
  • \(x_2^2\)
  • \(x_1x_2\)

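Of the nine products generated by powers(2, 2), we therefore need to exclude the intercept \(1\), the linear terms \(x_1\) and \(x_2\), and the terms of degree above 2: \(x_1^2x_2\), \(x_1x_2^2\) and \(x_1^2x_2^2\). In general, we generate all powers(n, n) products and keep only the rows whose exponents sum to \(n\).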
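To get an idea of how much the filter discards (a counting argument I am adding here, based on the standard stars-and-bars result), powers(n, n) returns \((n+1)^n\) rows, while the number of exponent vectors of length \(n\) with non-negative entries summing to \(n\) is

\[
\binom{n + n - 1}{n - 1} = \binom{2n - 1}{n}.
\]

For \(n = 2\), this keeps \(\binom{3}{2} = 3\) of the \(3^2 = 9\) rows; for \(n = 4\), it keeps \(\binom{7}{4} = 35\) of the \(5^4 = 625\) rows.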

interactions <- function(n){
  
  # all powered products with exponents from 0 to n
  m <- powers(n, n)
  
  # rows whose exponents sum to n
  r <- rowSums(m) == n
  
  # select those rows; drop = FALSE keeps a matrix even for a single row
  m_f <- m[r, , drop = FALSE]
  return(m_f)
}

Then we have:

interactions(2)
##      [,1] [,2]
## [1,]    2    0
## [2,]    1    1
## [3,]    0    2
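To use these exponent rows as features, each row has to be evaluated on the data. A minimal sketch, assuming the data come as a numeric matrix X with one column per variable; interaction_features() is a hypothetical helper name, and the exponent matrix E is typed in by hand from the interactions(2) output above.

```r
# Hypothetical helper: evaluate each exponent row of E on the rows of X
interaction_features <- function(X, E) {
  stopifnot(ncol(X) == ncol(E))
  # column k holds prod_j X[, j]^E[k, j] for every observation
  f <- vapply(seq_len(nrow(E)),
              function(k) apply(X, 1, function(row) prod(row^E[k, ])),
              numeric(nrow(X)))
  # vapply() returns a plain vector when X has a single row
  matrix(f, nrow = nrow(X))
}

# exponent rows returned by interactions(2)
E <- matrix(c(2, 0,
              1, 1,
              0, 2), ncol = 2, byrow = TRUE)
X <- matrix(c(1, 2,
              3, 4), ncol = 2, byrow = TRUE)
interaction_features(X, E)
##      [,1] [,2] [,3]
## [1,]    1    2    4
## [2,]    9   12   16
```

Since R evaluates 0^0 as 1, a zero exponent leaves the product unchanged even when the variable itself is zero.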

For \(n = 4\), we obtain:

interactions(4)
##       [,1] [,2] [,3] [,4]
##  [1,]    4    0    0    0
##  [2,]    3    1    0    0
##  [3,]    2    2    0    0
##  [4,]    1    3    0    0
##  [5,]    0    4    0    0
##  [6,]    3    0    1    0
##  [7,]    2    1    1    0
##  [8,]    1    2    1    0
##  [9,]    0    3    1    0
## [10,]    2    0    2    0
## [11,]    1    1    2    0
## [12,]    0    2    2    0
## [13,]    1    0    3    0
## [14,]    0    1    3    0
## [15,]    0    0    4    0
## [16,]    3    0    0    1
## [17,]    2    1    0    1
## [18,]    1    2    0    1
## [19,]    0    3    0    1
## [20,]    2    0    1    1
## [21,]    1    1    1    1
## [22,]    0    2    1    1
## [23,]    1    0    2    1
## [24,]    0    1    2    1
## [25,]    0    0    3    1
## [26,]    2    0    0    2
## [27,]    1    1    0    2
## [28,]    0    2    0    2
## [29,]    1    0    1    2
## [30,]    0    1    1    2
## [31,]    0    0    2    2
## [32,]    1    0    0    3
## [33,]    0    1    0    3
## [34,]    0    0    1    3
## [35,]    0    0    0    4
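Filtering keeps only a small share of the rows that powers(n, n) generates. As an alternative sketch (my suggestion, not the procedure described above), the same recursive idea can generate only the exponent vectors that sum to a given total: fix the exponent i of the last variable and recurse on the remaining k - 1 variables with total - i. The name compositions() is hypothetical.

```r
# Hypothetical alternative: recursively build only the exponent vectors
# of length k whose entries are non-negative and sum to total
compositions <- function(total, k) {
  if (k == 1) {
    # base case: a single variable takes the whole total
    return(matrix(total, 1, 1))
  }
  # recursive step: last exponent i, remaining k - 1 exponents sum to total - i
  rows <- lapply(0:total, function(i) cbind(compositions(total - i, k - 1), i))
  m <- do.call(rbind, rows)
  colnames(m) <- NULL
  m
}

m4 <- compositions(4, 4)
nrow(m4)
## [1] 35
all(rowSums(m4) == 4)
## [1] TRUE
```

Because the last exponent varies slowest here too, compositions(n, n) should return the same rows, in the same order, as interactions(n), without building the full \((n+1)^n\) grid first.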

Session Info

## R version 4.4.0 (2024-04-24)
## Platform: x86_64-pc-linux-gnu
## Running under: Linux Mint 21.1
## 
## Matrix products: default
## BLAS:   /usr/lib/x86_64-linux-gnu/blas/libblas.so.3.10.0 
## LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.10.0
## 
## locale:
##  [1] LC_CTYPE=es_ES.UTF-8       LC_NUMERIC=C              
##  [3] LC_TIME=es_ES.UTF-8        LC_COLLATE=es_ES.UTF-8    
##  [5] LC_MONETARY=es_ES.UTF-8    LC_MESSAGES=es_ES.UTF-8   
##  [7] LC_PAPER=es_ES.UTF-8       LC_NAME=C                 
##  [9] LC_ADDRESS=C               LC_TELEPHONE=C            
## [11] LC_MEASUREMENT=es_ES.UTF-8 LC_IDENTIFICATION=C       
## 
## time zone: Europe/Madrid
## tzcode source: system (glibc)
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## loaded via a namespace (and not attached):
##  [1] digest_0.6.35     R6_2.5.1          bookdown_0.39     fastmap_1.1.1    
##  [5] xfun_0.43         blogdown_1.19     cachem_1.0.8      knitr_1.46       
##  [9] htmltools_0.5.8.1 rmarkdown_2.26    lifecycle_1.0.4   cli_3.6.2        
## [13] sass_0.4.9        jquerylib_0.1.4   compiler_4.4.0    rstudioapi_0.16.0
## [17] tools_4.4.0       evaluate_0.23     bslib_0.7.0       yaml_2.3.8       
## [21] jsonlite_1.8.8    rlang_1.1.3