Runtime Dispatch for Templated CUDA Kernels
When writing CUDA code, you often face a trade-off:
- Compile-time parameters (via templates) allow the compiler to fully unroll loops and optimize aggressively.
- Runtime parameters make the host code more flexible, but you lose compile-time benefits unless you generate every possibility.
If you try to mix both, you often end up with a big ugly switch statement that lists all possible template instantiations.
This post shows a way to avoid that with a dispatch table and a strict separation between host and device code.
The Problem
Imagine you have a CUDA kernel templated on two integers M and N:
template<int M, int N>
__global__ void myKernel(const float* in, float* out);
You want:
- Device side: M and N known at compile time for full unrolling and optimization.
- Host side: M and N passed in at runtime.
Naively, you might write:
switch (M) {
    case 1:
        switch (N) {
            case 1: myKernel<1,1><<<...>>>(); break;
            case 2: myKernel<1,2><<<...>>>(); break;
            // ...
        }
        break;
    // ...
}
But this quickly becomes unmanageable.
The Solution: Dispatch Table
We can store function pointers to specific template instantiations in a table, and look them up at runtime.
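The idea can be tried without a GPU. The following is a minimal host-only sketch, not the code from this post: `elementCount`, `Fn`, and `dispatch` are hypothetical names, and a plain function template stands in for the templated kernel launcher.

```cpp
#include <map>
#include <stdexcept>
#include <utility>

// Stand-in for a templated kernel launcher (hypothetical). M and N are
// compile-time constants here, so M * N can be folded by the compiler.
template <int M, int N>
int elementCount() {
    return M * N;
}

// Every instantiation shares this signature, so all of them fit in one table.
using Fn = int (*)();

// Maps a runtime (M, N) pair to one specific instantiation and calls it.
inline int dispatch(int M, int N) {
    static const std::map<std::pair<int, int>, Fn> table = {
        {{1, 2}, &elementCount<1, 2>},
        {{2, 4}, &elementCount<2, 4>},
    };
    auto it = table.find({M, N});
    if (it == table.end()) {
        throw std::out_of_range("no instantiation for this (M, N)");
    }
    return it->second();
}
```

Calling `dispatch(2, 4)` runs `elementCount<2, 4>`, while `dispatch(3, 3)` throws because that pair was never instantiated. The key point is that the template arguments disappear from the function pointer type, which is what allows a single table to hold all instantiations.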
File Structure
We’ll split the code into:
caller.cpp       <-- host-only code, runtime M & N
cudaFunction.h   <-- interface only, host-safe
cudaFunction.cu  <-- implementation, compiled by nvcc
cudaFunction.h
This header exposes a clean interface:
#pragma once
#include <utility> // for std::pair
// Public API: callable from pure host code
void launchCudaFunction(int M, int N, const float* d_in, float* d_out);
cudaFunction.cu
This is compiled with nvcc and contains all CUDA-specific code:
#include "cudaFunction.h"
#include <cuda_runtime.h>
#include <cstddef>      // std::size_t
#include <functional>   // std::hash
#include <stdexcept>
#include <unordered_map>
#include <utility>      // std::pair
// M and N are compile-time constants inside the kernel.
template<int M, int N>
__global__ void myKernel(const float* in, float* out) {
    int idx = threadIdx.x + blockIdx.x * blockDim.x;
    if (idx < M * N) {
        out[idx] = in[idx] * 2.0f; // Example work
    }
}
// Host-side wrapper: one instantiation per (M, N), all with the same signature.
template<int M, int N>
void kernelLauncher(const float* in, float* out) {
    dim3 threads(256);
    dim3 blocks((M * N + threads.x - 1) / threads.x);
    myKernel<M, N><<<blocks, threads>>>(in, out);
    // Report launch failures instead of silently ignoring them.
    if (cudaError_t err = cudaGetLastError(); err != cudaSuccess) {
        throw std::runtime_error(cudaGetErrorString(err));
    }
}
using Key = std::pair<int,int>;
using LauncherFn = void (*)(const float*, float*);
// hash_combine function (similar to boost::hash_combine)
inline std::size_t hash_combine(std::size_t lhs, std::size_t rhs) {
    lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
    return lhs;
}
// std::pair has no std::hash specialization, so the map needs a custom hasher.
struct PairHash {
    std::size_t operator()(const Key& k) const {
        std::size_t temp = std::hash<int>()(k.first);
        std::size_t temp2 = std::hash<int>()(k.second);
        return hash_combine(temp, temp2);
    }
};
// Each entry forces one template instantiation and stores a pointer to it.
static const std::unordered_map<Key, LauncherFn, PairHash> dispatchTable = {
    {{1, 1}, kernelLauncher<1, 1>},
    {{1, 2}, kernelLauncher<1, 2>},
    {{2, 4}, kernelLauncher<2, 4>},
    {{4, 8}, kernelLauncher<4, 8>}
};
void launchCudaFunction(int M, int N, const float* d_in, float* d_out) {
    // The if-with-initializer syntax requires C++17.
    if (auto it = dispatchTable.find({M, N}); it != dispatchTable.end()) {
        it->second(d_in, d_out);
    } else {
        throw std::runtime_error("Unsupported M, N combination");
    }
}
caller.cpp
This is pure host code and does not need nvcc. A regular C++ compiler is enough, as long as it can find the CUDA runtime headers:
#include "cudaFunction.h"
#include <iostream>
#include <vector>
#include <cuda_runtime.h>
int main() {
    int M = 2; // runtime value
    int N = 4; // runtime value
    size_t size = M * N;
    float* d_in;
    float* d_out;
    cudaMalloc(&d_in, size * sizeof(float));
    cudaMalloc(&d_out, size * sizeof(float));
    std::vector<float> h_in(size, 1.0f);
    cudaMemcpy(d_in, h_in.data(), size * sizeof(float), cudaMemcpyHostToDevice);
    try {
        launchCudaFunction(M, N, d_in, d_out);
    } catch (const std::exception& e) {
        std::cerr << "Error: " << e.what() << "\n";
        cudaFree(d_in);
        cudaFree(d_out);
        return 1;
    }
    std::vector<float> h_out(size);
    cudaMemcpy(h_out.data(), d_out, size * sizeof(float), cudaMemcpyDeviceToHost);
    std::cout << "Result: ";
    for (auto v : h_out) std::cout << v << " ";
    std::cout << "\n";
    cudaFree(d_in);
    cudaFree(d_out);
    return 0;
}
Building
Without CMake:
nvcc -std=c++17 -c cudaFunction.cu -o cudaFunction.o
g++ -I/usr/local/cuda/include -c caller.cpp -o caller.o
g++ caller.o cudaFunction.o -o app -L/usr/local/cuda/lib64 -lcudart
With CMake:
cmake_minimum_required(VERSION 3.17)
project(dispatchExample LANGUAGES CXX CUDA)
find_package(CUDAToolkit REQUIRED)
add_library(cudaFunction STATIC cudaFunction.cu)
set_target_properties(cudaFunction PROPERTIES CUDA_STANDARD 17)
add_executable(caller caller.cpp)
target_link_libraries(caller PRIVATE cudaFunction CUDA::cudart)
Why This Works
- All CUDA specifics are isolated in .cu.
- The host compiler sees only a clean function declaration.
- Adding new (M, N) combinations is as simple as adding a line to dispatchTable.
- Kernels still get full compile-time optimization and loop unrolling.
This pattern keeps your runtime code flexible without sacrificing performance.
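Each table line names the pair twice, once as the key and once as the template arguments, and nothing stops the two from disagreeing. One possible way to avoid that is to derive the key from the template arguments. This is a host-only sketch with hypothetical names: `entry` and `lookup` are not part of the code above, and `elementCount` stands in for `kernelLauncher`.

```cpp
#include <map>
#include <utility>

using Key = std::pair<int, int>;
using Fn = int (*)();

// Stand-in for kernelLauncher<M, N> (hypothetical).
template <int M, int N>
int elementCount() {
    return M * N;
}

// Builds one table entry whose key is taken from the template arguments,
// so the key and the instantiation cannot get out of sync.
template <int M, int N>
std::pair<const Key, Fn> entry() {
    return {Key{M, N}, &elementCount<M, N>};
}

// Returns the matching instantiation, or nullptr if none was compiled.
inline Fn lookup(int M, int N) {
    static const std::map<Key, Fn> table = {
        entry<1, 1>(),
        entry<1, 2>(),
        entry<2, 4>(),
        entry<4, 8>(),
    };
    auto it = table.find({M, N});
    return it == table.end() ? nullptr : it->second;
}
```

With this helper, supporting a new combination is still one line, but the pair is written only once.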
Drawbacks:
- A missing (M, N) combination is only detected at runtime, when the table lookup fails and launchCudaFunction throws.
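One way to soften this, assuming callers are able to check before they launch, is to expose the supported combinations and to name them in the error message. The sketch below is host-only and uses hypothetical helpers (`supportedShapes`, `isSupported`, `unsupportedMessage`). Its set of pairs mirrors the keys of `dispatchTable` and would have to be kept in sync with the table, or be derived from it.

```cpp
#include <set>
#include <sstream>
#include <string>
#include <utility>

using Key = std::pair<int, int>;

// Assumed copy of the dispatchTable keys (hypothetical helper).
inline const std::set<Key>& supportedShapes() {
    static const std::set<Key> shapes = {{1, 1}, {1, 2}, {2, 4}, {4, 8}};
    return shapes;
}

// Lets host code validate user input before any device work is set up.
inline bool isSupported(int M, int N) {
    return supportedShapes().count({M, N}) != 0;
}

// Builds an error message that lists what was actually compiled.
inline std::string unsupportedMessage(int M, int N) {
    std::ostringstream msg;
    msg << "Unsupported (M, N) = (" << M << ", " << N << "); compiled:";
    for (const auto& s : supportedShapes()) {
        msg << " (" << s.first << ", " << s.second << ")";
    }
    return msg.str();
}
```

A caller could test `isSupported(M, N)` while parsing its configuration, which moves the failure closer to startup than to the first kernel launch. The check remains a runtime check, though; it only fails earlier and with a more useful message.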