Source code for dswe.covmatch

# Copyright (c) 2022 Pratyush Kumar, Abhinav Prakash, and Yu Ding

# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
from .utils import validate_matching
from ._covmatch_subroutine import *


class CovMatch(object):
    """
    Match two data sets on their covariates.

    The matching subroutine is run in both directions (data set 0 against
    data set 1, then data set 1 against data set 0). For each data set the
    rows returned by the two runs are stacked and exact duplicate rows are
    dropped, keeping the first occurrence in stacking order.

    Parameters
    ----------
    Xlist: list
        A list consisting of the two data sets to match. Each data set can be
        a matrix with each column corresponding to one input variable.

    ylist: list
        A list consisting of the target values of the two data sets, one array
        per data set. Default is None.

    circ_pos: list or int
        A list or array stating the column position(s) of circular variables,
        or a single integer when only one circular variable is present
        (``0`` is a valid position). Default is None.

    thresh: float or list
        A single threshold value, or a list or 1d array with one threshold
        per covariate, against which matching happens. Default value is 0.2.

    priority: bool
        When True, the covariate columns are reordered by decreasing absolute
        difference between the column means of the two data sets before
        matching. Default is False.

    Attributes
    ----------
    matched_data_X: list
        The variable values of the two data sets after matching.

    matched_data_y: list or None
        The response values of the two data sets after matching
        (None when ``ylist`` is not provided).

    min_max_original:
        The minimum and maximum value in the original data for each covariate
        used in matching.

    min_max_matched:
        The minimum and maximum value in the matched data for each covariate
        used in matching.

    Raises
    ------
    ValueError
        If ``circ_pos``, ``thresh`` or ``priority`` fails validation.
    """

    def __init__(self, Xlist, ylist=None, circ_pos=None, thresh=0.2, priority=False):

        validate_matching(Xlist, ylist)

        circ_pos = self._normalize_circ_pos(circ_pos)

        # Work on shallow copies of the containers so that the caller's lists
        # are not modified (np.array already copies the individual data sets).
        X = list(Xlist)
        X[0] = np.array(X[0])
        X[1] = np.array(X[1])

        if isinstance(thresh, (list, np.ndarray)):
            if len(thresh) > 0 and len(thresh) != X[0].shape[1]:
                raise ValueError(
                    "The thresh must be a single value, or list or 1d array with weight for each covariate.")

        if not isinstance(priority, (bool, np.bool_)):
            raise ValueError("The priority must be either True or False.")

        # An empty ylist is treated like None, as before.
        has_y = ylist is not None and len(ylist) > 0
        if has_y:
            y = list(ylist)
            y[0] = np.array(y[0]).reshape(-1, 1)
            y[1] = np.array(y[1]).reshape(-1, 1)
        else:
            y = ylist

        if priority:
            # Covariates whose means differ the most come first.
            # NOTE(review): only the data columns are permuted here; circ_pos
            # and a per-covariate thresh keep the caller's original column
            # order. Confirm against `matching` whether they need remapping.
            mean_gap = np.abs(np.mean(X[0], axis=0) - np.mean(X[1], axis=0))
            order = np.argsort(-mean_gap)
            X[0] = X[0][:, order]
            X[1] = X[1][:, order]

        self.Xlist = X
        self.ylist = y
        self.thresh = thresh
        self.circ_pos = circ_pos

        # Run the matching in both directions: (0 vs 1) and (1 vs 0).
        if has_y:
            fwd_X, fwd_y = matching(
                [X[0], X[1]], [y[0], y[1]], self.thresh, self.circ_pos)
            rev_X, rev_y = matching(
                [X[1], X[0]], [y[1], y[0]], self.thresh, self.circ_pos)
        else:
            fwd_X = matching([X[0], X[1]], self.ylist,
                             self.thresh, self.circ_pos)
            rev_X = matching([X[1], X[0]], self.ylist,
                             self.thresh, self.circ_pos)

        # In the reversed run the data sets swap places, so data set 0 is at
        # index 1 of `rev_*` and data set 1 is at index 0.
        matched_X0, keep0 = self._stack_unique_rows(fwd_X[0], rev_X[1])
        matched_X1, keep1 = self._stack_unique_rows(fwd_X[1], rev_X[0])
        self.matched_data_X = [matched_X0, matched_X1]

        self.matched_data_y = None
        if has_y:
            stacked_y0 = np.concatenate([fwd_y[0], rev_y[1]]).astype(float)
            stacked_y1 = np.concatenate([fwd_y[1], rev_y[0]]).astype(float)
            # Responses follow the same row selection as their covariates.
            self.matched_data_y = [
                np.squeeze(stacked_y0[keep0]),
                np.squeeze(stacked_y1[keep1]),
            ]

        self.min_max_original = min_max(self.Xlist)
        self.min_max_matched = min_max(self.matched_data_X)

    @staticmethod
    def _normalize_circ_pos(circ_pos):
        """Validate circ_pos and return it as a list of ints (or None)."""
        # Compare against None explicitly: column position 0 is falsy, and a
        # multi-element ndarray has no unambiguous truth value.
        if circ_pos is None:
            return None
        if isinstance(circ_pos, (int, np.integer)) and not isinstance(circ_pos, bool):
            return [int(circ_pos)]
        if isinstance(circ_pos, np.ndarray):
            return circ_pos.ravel().tolist()
        if isinstance(circ_pos, list):
            return circ_pos
        raise ValueError(
            "The circ_pos should be a list or 1d-array or single integer value or set to None.")

    @staticmethod
    def _stack_unique_rows(first, second):
        """Stack two row sets and drop duplicate rows, keeping first occurrences.

        Returns the de-duplicated float array and the sorted indices (into the
        stacked array) of the rows that were kept.
        """
        stacked = np.concatenate([first, second]).astype(float)
        first_seen = np.unique(stacked, axis=0, return_index=True)[1]
        keep = np.sort(first_seen)
        return stacked[keep], keep