PyTorch Quantization of model parameters for deployment on edge device

Essentially, I have a trained PyTorch model that I want to deploy on an edge device (all written in C/C++) for inference. For context, I'm working alone on this project, so I don't get much guidance. My understanding is that, at deployment, the inference input data needs to be integers, and my model's parameters (weights and biases/activations) also need to be integers. Because I don't have "inference data" yet, I am currently prototyping by quantizing my validation/test data and comparing the validation/test results I get with the floating-point model parameters against the results I get with the quantized/integer model parameters. To make this more concrete, I'm testing two cases (a small comparison sketch appears after the inference code below):

  • Case 1: the floating-point model called on floating-point train and test data.
  • Case 2: the model with quantized/integer parameters called on quantized test data.

            def quantize_tensor(tensor, num_bits):
                # Signed integer range for the chosen bit width.
                qmin = -(2 ** (num_bits - 1))
                qmax = (2 ** (num_bits - 1)) - 1

                # Affine (asymmetric) quantization: x is mapped so that
                # x is approximately scale * (q - zero_point).
                min_val, max_val = tensor.min(), tensor.max()
                scale = (max_val - min_val) / (qmax - qmin)
                zero_point = qmin - min_val / scale
                zero_point = torch.round(zero_point).clamp(qmin, qmax)

                q_tensor = torch.round(tensor / scale + zero_point).clamp(qmin, qmax)

                # Store in the smallest integer dtype that fits the requested bit width.
                if num_bits == 8:
                    q_tensor = q_tensor.type(torch.int8)
                elif num_bits == 16:
                    q_tensor = q_tensor.type(torch.int16)
                else:
                    q_tensor = q_tensor.type(torch.int)

                return q_tensor, scale, zero_point
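
As a sanity check on quantize_tensor, the affine mapping can be inverted: x ≈ scale * (q - zero_point), so the original values should come back within roughly scale/2. The snippet below is just a minimal sketch of that round-trip check (dequantize_tensor is a hypothetical helper, not part of my actual pipeline):

            def dequantize_tensor(q_tensor, scale, zero_point):
                # Inverse of the affine mapping used in quantize_tensor above.
                return scale * (q_tensor.to(torch.float32) - zero_point)

            # Round-trip check: reconstruction error should be on the order of scale / 2.
            x = torch.randn(4, 8)
            q, s, zp = quantize_tensor(x, 16)
            print((x - dequantize_tensor(q, s, zp)).abs().max())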

Then I quantize the model's weights and biases using this:

            def quantize_model(model, weight_bit_width=16, bias_bit_width=16):
                quantized_state_dict = {}
                scale_zp_dict = {}  # To store scale and zero-point for each parameter

                for name, param in model.state_dict().items():
                    if 'weight' in name:
                        q_param, scale, zero_point = quantize_tensor(param, weight_bit_width)
                        quantized_state_dict[name] = q_param
                        scale_zp_dict[name] = (scale, zero_point)
                    elif 'bias' in name:
                        q_param, scale, zero_point = quantize_tensor(param, bias_bit_width)
                        quantized_state_dict[name] = q_param
                        scale_zp_dict[name] = (scale, zero_point)
                    else:
                        # For other parameters, keep them as is or apply appropriate quantization
                        quantized_state_dict[name] = param

                return quantized_state_dict, scale_zp_dict
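
To evaluate the quantized parameters without running into the integer-softmax issue, one option is to dequantize them back to float and load them into a float copy of the model, i.e. simulate the quantization error while keeping all arithmetic in floating point. A minimal sketch of that idea (load_dequantized is a hypothetical helper, not something I currently use):

            def load_dequantized(fp_model, quantized_state_dict, scale_zp_dict):
                # Rebuild a float state_dict: x_hat = scale * (q - zero_point).
                state_dict = {}
                for name, q_param in quantized_state_dict.items():
                    if name in scale_zp_dict:
                        scale, zero_point = scale_zp_dict[name]
                        state_dict[name] = scale * (q_param.to(torch.float32) - zero_point)
                    else:
                        state_dict[name] = q_param
                fp_model.load_state_dict(state_dict)
                return fp_model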

Furthermore, I quantize my model and the data as shown below. However, because my ML problem is a multiclass, multi-output problem, I need to call torch.softmax on the logits my model produces to get prediction probabilities, but softmax is not implemented for integer tensors, which makes me worried that my overall quantization approach is wrong (I've included the model's code and the rest below):

import copy

import torch
import torch.nn as nn

class model(nn.Module):
    def __init__(self, inputs, l1, l2, num_outputs, output_classes=3):
        super().__init__()

        # define the layers
        self.output_classes = output_classes
        self.num_outputs = num_outputs

        self.layers = nn.Sequential(
            nn.Linear(inputs, l1),
            nn.ReLU(),
            nn.Linear(l1, l2),
            nn.ReLU(),
            nn.Linear(l2, num_outputs * output_classes),  # output_classes = number of classes in each output
        )

    def forward(self, x):
        x = self.layers(x)
        x = x.view(-1, self.output_classes, self.num_outputs)  # Reshapes output tensor (logits output).
        return x


model_copy = copy.deepcopy(floating_point_trained_model)

# quantize model params
quantized_state_dict, scale_zp_dict = quantize_model(model_copy, weight_bit_width=16, bias_bit_width=16)
for name, param in model_copy.named_parameters():
    param.requires_grad = False
    param.data = quantized_state_dict[name].to(dtype=torch.float) # <--- Need help here: Casting to float to satisfy softmax requirements 

# Quantize data 
Quant_X_train, scale, zp = quantize_tensor(X_train, 16)          # quantize the training data
Quant_X_test, test_scale, test_zp = quantize_tensor(X_test, 16)  # quantize the test data

# call quantized model on quantized input data 
pred_probs = torch.softmax(model_copy(Quant_X_test.to(torch.float)), dim=1)  # <--- Need help: casting to float to get prediction probabilities
predictions = torch.argmax(pred_probs, dim=1) 
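
For the actual Case 1 vs Case 2 comparison, here is a rough sketch of the kind of check I have in mind (it assumes floating_point_trained_model and X_test from above). Note that softmax is monotonic, so argmax over the raw logits already gives the same class predictions as argmax over the probabilities:

with torch.no_grad():
    float_logits = floating_point_trained_model(X_test)       # Case 1: float model, float data
    quant_logits = model_copy(Quant_X_test.to(torch.float))   # Case 2: quantized model, quantized data

float_preds = torch.argmax(float_logits, dim=1)   # argmax over the class dimension
quant_preds = torch.argmax(quant_logits, dim=1)
agreement = (float_preds == quant_preds).float().mean()
print(f"Case 1 vs Case 2 prediction agreement: {agreement.item():.4f}")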

I'm curious about a few things:

  • Whether this is the correct process/way to approach this problem.
    • Especially because I am not able to call softmax on my int tensor, I feel I might be doing something wrong.
  • Whether I implemented the quantization procedure accurately.
    • i.e., does my method of verification make sense? (The method being: comparing the results of Case 1 and Case 2 above.)
  • If anyone has guidance on how to approach this problem (or example code/tutorials), that would be great. I have perused PyTorch's documentation on its supported quantization modes.
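
For reference, PyTorch's post-training dynamic quantization seems like the closest built-in match for a Linear-only model like mine; it quantizes nn.Linear weights to int8 and keeps activations in float. A minimal sketch of using it as a baseline against my manual approach (depending on the PyTorch version, the API lives under torch.ao.quantization or the older torch.quantization alias):

import torch.ao.quantization as tq

# Post-training dynamic quantization: int8 weights for nn.Linear, float activations.
dyn_quant_model = tq.quantize_dynamic(
    copy.deepcopy(floating_point_trained_model),
    {nn.Linear},        # module types to quantize
    dtype=torch.qint8,
)

# The int8 weight and its scale/zero_point can then be read out of each quantized
# Linear module (e.g. via its weight() method) for a custom C/C++ runtime.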

If it helps, this is an example of what my training data looks like:

0      0.995231  0.996840  1.000000  0.998341  1.000000  1.000000  1.000000  0.998709  ...         0.000024         0.000019         0.000015         0.000016         0.000011         0.000007         0.000007         0.000015
1      0.996407  0.998568  1.000000  0.997889  1.000000  0.999954  0.999738  0.997458  ...         0.000018         0.000013         0.000011         0.000012         0.000008         0.000005         0.000006         0.000009
2      0.996083  0.999702  1.000000  0.999031  1.000000  1.000000  0.999816  0.998727  ...         0.000019         0.000013         0.000012         0.000011         0.000008         0.000006         0.000006         0.000011
3      0.998531  0.999481  0.999199  1.000000  0.999720  1.000000  1.000000  0.998682  ...         0.000015         0.000011         0.000010         0.000010         0.000007         0.000005         0.000004         0.000007