
Runtime Dispatch for Templated CUDA Kernels

When writing CUDA code, you often face a trade-off:

  • Compile-time parameters (via templates) allow the compiler to fully unroll loops and optimize aggressively.
  • Runtime parameters make the host code more flexible, but you lose the compile-time benefits unless you instantiate every combination you might need.

If you try to mix both, you often end up with a big ugly switch statement that lists all possible template instantiations.

This post shows a cleaner approach: a dispatch table that maps runtime (M, N) values to precompiled template instantiations, combined with a strict separation between host and device code.

The Problem

Imagine you have a CUDA kernel templated on two integers M and N:

cpp
template<int M, int N>
__global__ void myKernel(const float* in, float* out);

You want:

  • Device side: M and N known at compile time for full unrolling and optimization.
  • Host side: M and N passed in at runtime.

Naively, you might write:

cpp
switch (M) {
  case 1:
    switch (N) {
      case 1: myKernel<1,1><<<...>>>(); break;
      case 2: myKernel<1,2><<<...>>>(); break;
      // ...
    }
    break;
  // ...
}

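But this quickly becomes unmanageable: the number of cases is the number of M values times the number of N values, and every new combination means editing nested switch statements.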

The Solution: Dispatch Table

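We can store pointers to specific template instantiations in a table keyed by (M, N) and look them up at runtime. Listing an instantiation in the table forces the compiler to generate it, so every kernel in the table is compiled ahead of time.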
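The idea can be tried without a GPU. The host-only sketch below uses a hypothetical function template, `scaledSum<M, N>`, as a stand-in for a kernel launcher, and maps runtime `(M, N)` pairs to pointers to its instantiations. It uses `std::map` instead of the post's `std::unordered_map` so that no custom hash is needed; that is a simplification for illustration, not the post's code.

```cpp
#include <cassert>
#include <map>
#include <stdexcept>
#include <utility>

// Hypothetical stand-in for a kernel launcher: M and N are compile-time
// constants, so the loop trip count M * N is known to the compiler.
template <int M, int N>
float scaledSum(const float* in) {
    assert(in != nullptr);
    float acc = 0.0f;
    for (int i = 0; i < M * N; ++i) {
        acc += in[i] * 2.0f;
    }
    return acc;
}

using Fn = float (*)(const float*);

// One entry per instantiation we want compiled. std::map orders
// std::pair<int, int> lexicographically, so no custom hash is needed.
const std::map<std::pair<int, int>, Fn>& table() {
    static const std::map<std::pair<int, int>, Fn> t = {
        {{1, 1}, &scaledSum<1, 1>},
        {{1, 2}, &scaledSum<1, 2>},
        {{2, 4}, &scaledSum<2, 4>},
    };
    return t;
}

// Runtime M and N select a compile-time instantiation.
float dispatch(int M, int N, const float* in) {
    auto it = table().find({M, N});
    if (it == table().end()) {
        throw std::runtime_error("unsupported (M, N)");
    }
    return it->second(in);
}
```

For example, `dispatch(2, 4, data)` on eight ones calls `scaledSum<2, 4>` and yields 16, while `dispatch(3, 3, data)` throws because no such instantiation exists.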

File Structure

We’ll split the code into:

caller.cpp        <-- host-only code, runtime M & N
cudaFunction.h    <-- interface only, host-safe
cudaFunction.cu   <-- implementation, compiled by nvcc

cudaFunction.h

This header exposes a clean interface:

cpp
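#pragma once

// Public API: callable from pure host code (no CUDA headers needed).
// Launches the kernel instantiation matching (M, N) on the device buffers
// d_in and d_out; throws std::runtime_error if that combination was not
// compiled into the dispatch table.
void launchCudaFunction(int M, int N, const float* d_in, float* d_out);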

cudaFunction.cu

This is compiled with nvcc and contains all CUDA-specific code:

cpp
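#include "cudaFunction.h"
#include <cuda_runtime.h>
#include <cstddef>       // std::size_t
#include <functional>    // std::hash
#include <stdexcept>
#include <unordered_map>
#include <utility>       // std::pair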

template<int M, int N>
__global__ void myKernel(const float* in, float* out) {
    int idx = threadIdx.x + blockIdx.x * blockDim.x;
    if (idx < M * N) {
        out[idx] = in[idx] * 2.0f; // Example work
    }
}

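template<int M, int N>
void kernelLauncher(const float* in, float* out) {
    dim3 threads(256);
    dim3 blocks((M * N + threads.x - 1) / threads.x);
    myKernel<M, N><<<blocks, threads>>>(in, out);
    // Launches are asynchronous: this catches launch errors (e.g. an
    // invalid configuration), not faults that occur during execution.
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        throw std::runtime_error(cudaGetErrorString(err));
    }
}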

using Key = std::pair<int,int>;
using LauncherFn = void (*)(const float*, float*);

// Robust hash_combine function (similar to boost::hash_combine)
inline std::size_t hash_combine(std::size_t lhs, std::size_t rhs) {
    lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
    return lhs;
}

struct PairHash {
    std::size_t operator()(const Key& k) const {
        std::size_t temp = std::hash<int>()(k.first);
        std::size_t temp2 = std::hash<int>()(k.second);
        return hash_combine(temp, temp2);
    }
};

static const std::unordered_map<Key, LauncherFn, PairHash> dispatchTable = {
    {{1, 1}, kernelLauncher<1, 1>},
    {{1, 2}, kernelLauncher<1, 2>},
    {{2, 4}, kernelLauncher<2, 4>},
    {{4, 8}, kernelLauncher<4, 8>}
};

void launchCudaFunction(int M, int N, const float* d_in, float* d_out) {
    if (auto it = dispatchTable.find({M, N}); it != dispatchTable.end()) {
        it->second(d_in, d_out);
    } else {
        throw std::runtime_error("Unsupported M, N combination");
    }
}
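Because M and N are small non-negative integers, one alternative to `PairHash` is to pack both into a single 64-bit integer and use that as the key: distinct pairs then always get distinct keys, and `std::hash<std::uint64_t>` does the rest. This is a swapped-in technique, not part of the original code; `packKey` and `PackedTable` are hypothetical names, and the sketch assumes M and N are non-negative and fit in 32 bits.

```cpp
#include <cassert>
#include <cstdint>
#include <unordered_map>

// Hypothetical helper: M in the high 32 bits, N in the low 32 bits.
// Distinct non-negative (M, N) pairs always yield distinct keys.
inline std::uint64_t packKey(int M, int N) {
    assert(M >= 0 && N >= 0);
    return (static_cast<std::uint64_t>(static_cast<std::uint32_t>(M)) << 32) |
           static_cast<std::uint32_t>(N);
}

using LauncherFn = void (*)(const float*, float*);

// The standard library already provides std::hash<std::uint64_t>,
// so no custom hash functor is required.
using PackedTable = std::unordered_map<std::uint64_t, LauncherFn>;
```

With this key type, the lookup in `launchCudaFunction` would become `dispatchTable.find(packKey(M, N))`, and `PairHash` and `hash_combine` would no longer be needed.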

caller.cpp

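This is pure host code and does not need nvcc. It includes cuda_runtime.h for cudaMalloc and cudaMemcpy, so it still needs the CUDA include path and runtime library: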

cpp
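#include "cudaFunction.h"
#include <cuda_runtime.h>
#include <exception>
#include <iostream>
#include <vector>

int main() {
    int M = 2; // runtime value
    int N = 4; // runtime value
    size_t size = static_cast<size_t>(M) * N;

    float* d_in = nullptr;
    float* d_out = nullptr;
    if (cudaMalloc(&d_in, size * sizeof(float)) != cudaSuccess ||
        cudaMalloc(&d_out, size * sizeof(float)) != cudaSuccess) {
        std::cerr << "cudaMalloc failed\n";
        cudaFree(d_in);  // cudaFree(nullptr) is a no-op
        cudaFree(d_out);
        return 1;
    }

    std::vector<float> h_in(size, 1.0f);
    cudaMemcpy(d_in, h_in.data(), size * sizeof(float), cudaMemcpyHostToDevice);

    try {
        launchCudaFunction(M, N, d_in, d_out);
    } catch (const std::exception& e) {
        std::cerr << "Error: " << e.what() << "\n";
        cudaFree(d_in);
        cudaFree(d_out);
        return 1;
    }

    // The launch is asynchronous; this blocking copy waits for the kernel
    // and also reports any error raised while it was running.
    std::vector<float> h_out(size);
    cudaError_t err = cudaMemcpy(h_out.data(), d_out, size * sizeof(float),
                                 cudaMemcpyDeviceToHost);
    if (err != cudaSuccess) {
        std::cerr << "CUDA error: " << cudaGetErrorString(err) << "\n";
    } else {
        std::cout << "Result: ";
        for (auto v : h_out) std::cout << v << " ";
        std::cout << "\n";
    }

    cudaFree(d_in);
    cudaFree(d_out);
    return err == cudaSuccess ? 0 : 1;
}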

Building

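Without CMake (the paths assume the CUDA toolkit is installed in /usr/local/cuda). caller.cpp includes cuda_runtime.h, so g++ needs the CUDA include directory, and the if statement with an initializer in cudaFunction.cu requires C++17:

bash
nvcc -std=c++17 -c cudaFunction.cu -o cudaFunction.o
g++ -std=c++17 -I/usr/local/cuda/include -c caller.cpp -o caller.o
g++ caller.o cudaFunction.o -o app -L/usr/local/cuda/lib64 -lcudart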

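With CMake (all device code lives in a single .cu file, so separable compilation is not required; the CUDA::cudart target comes from find_package(CUDAToolkit)):

cmake
cmake_minimum_required(VERSION 3.18)
project(dispatch LANGUAGES CXX CUDA)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 17)
find_package(CUDAToolkit REQUIRED)

add_library(cudaFunction STATIC cudaFunction.cu)
add_executable(caller caller.cpp)
target_link_libraries(caller PRIVATE cudaFunction CUDA::cudart)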

Why This Works

  • All CUDA specifics are isolated in cudaFunction.cu.
  • The host compiler sees only a plain function declaration.
  • Adding new (M, N) combinations is as simple as adding a line to dispatchTable.
  • Each kernel in the table is instantiated with M and N as compile-time constants, so it still gets full loop unrolling and optimization.

This pattern keeps your runtime code flexible without sacrificing performance.
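Adding one line per combination works well for a handful of entries. If the supported values form a full grid, the entries can also be generated. The C++17 sketch below is one possible way to do that with `std::integer_sequence` and fold expressions; it is an extension, not part of the post's code. `launcher<M, N>` is a hypothetical host-only placeholder for `kernelLauncher<M, N>`, and the value lists {1, 2, 4} and {1, 2, 4, 8} are arbitrary.

```cpp
#include <cassert>
#include <map>
#include <utility>

// Hypothetical host-only stand-in for kernelLauncher<M, N>; returns M * N
// so a caller can tell which instantiation was selected.
template <int M, int N>
int launcher() { return M * N; }

using Fn = int (*)();
using Table = std::map<std::pair<int, int>, Fn>;

// Adds launcher<M, N> for every N in Ns: one row of the grid.
template <int M, int... Ns>
void addRow(Table& t, std::integer_sequence<int, Ns...>) {
    (t.emplace(std::make_pair(M, Ns), static_cast<Fn>(&launcher<M, Ns>)), ...);
}

// Adds one row per M in Ms, i.e. the Cartesian product Ms x Ns.
template <typename NSeq, int... Ms>
void addGrid(Table& t, std::integer_sequence<int, Ms...>, NSeq ns) {
    (addRow<Ms>(t, ns), ...);
}

// Built once on first use: M in {1, 2, 4}, N in {1, 2, 4, 8}.
const Table& gridTable() {
    static const Table t = [] {
        Table tmp;
        addGrid(tmp, std::integer_sequence<int, 1, 2, 4>{},
                std::integer_sequence<int, 1, 2, 4, 8>{});
        return tmp;
    }();
    return t;
}
```

Every grid entry instantiates a kernel, so compile time and binary size grow with the product of the two lists; listing combinations by hand keeps that under explicit control.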

Drawbacks:

  • A combination that is missing from dispatchTable is only detected at runtime, when launchCudaFunction throws, not at compile time.
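The runtime failure can at least be made informative. One option, sketched below with hypothetical names (`isSupported`, `unsupportedMessage`, `launch`) and a plain `std::map` standing in for the post's `std::unordered_map`, is to let callers query support up front and to list the compiled combinations in the exception message.

```cpp
#include <cassert>
#include <map>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>

using Key = std::pair<int, int>;
using LauncherFn = int (*)();

// Hypothetical stand-in for kernelLauncher<M, N>.
template <int M, int N>
int launcher() { return M * N; }

// std::map keeps keys sorted, so the message below lists them in order.
const std::map<Key, LauncherFn>& table() {
    static const std::map<Key, LauncherFn> t = {
        {{1, 1}, &launcher<1, 1>}, {{1, 2}, &launcher<1, 2>},
        {{2, 4}, &launcher<2, 4>}, {{4, 8}, &launcher<4, 8>}};
    return t;
}

// Hypothetical query a caller could use before launching.
bool isSupported(int M, int N) { return table().count({M, N}) != 0; }

// Builds e.g. "Unsupported (3, 3); compiled: (1, 1) (1, 2) (2, 4) (4, 8)".
std::string unsupportedMessage(int M, int N) {
    std::ostringstream os;
    os << "Unsupported (" << M << ", " << N << "); compiled:";
    for (const auto& entry : table()) {
        os << " (" << entry.first.first << ", " << entry.first.second << ")";
    }
    return os.str();
}

int launch(int M, int N) {
    auto it = table().find({M, N});
    if (it == table().end()) {
        throw std::runtime_error(unsupportedMessage(M, N));
    }
    return it->second();
}
```

This does not move the check to compile time, but it turns an opaque failure into one that names the available instantiations.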

Created by Urs Hofmann