LST
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from iactrace import MCIntegrator, Telescope, hexshow, show_telescope
Load the LST telescope configuration:
# Build the telescope from its YAML description; the integrator and the
# PRNG key are handed to the loader as-is.
# NOTE(review): 256 is presumably the Monte Carlo sample count used by
# MCIntegrator — confirm against the iactrace API.
config_path = '../../configs/CTAO/LST_1_North_like.yaml'
integrator = MCIntegrator(256)
telescope = Telescope.from_yaml(
    config_path,
    integrator,
    key=jax.random.key(42),
)
Visualizing the telescope:
# Build a scene object from the loaded telescope and display it.
# NOTE(review): viewer='jupyter' presumably selects an inline notebook
# renderer — confirm against the scene object's show() documentation.
scene = show_telescope(telescope)
scene.show(viewer='jupyter')
Simulating a star field:
# --- Random star field ---------------------------------------------------
n_stars = 1000
key1, key2 = jax.random.split(jax.random.key(42))

# Offsets are drawn uniformly from a square 4.3 degrees on a side (small
# angular region), centred on the -z direction.
fov_deg = 4.3
fov_rad = fov_deg * jnp.pi / 180
half_fov = fov_rad / 2
x = jax.random.uniform(key1, (n_stars,), minval=-half_fov, maxval=half_fov)
y = jax.random.uniform(key2, (n_stars,), minval=-half_fov, maxval=half_fov)
z = -jnp.ones(n_stars)

# Assemble one (x, y, z) row per star and scale each row to unit length.
stars = jnp.stack((x, y, z), axis=-1)
row_norms = jnp.linalg.norm(stars, axis=1, keepdims=True)
stars = stars / row_norms

# Random fluxes: 10**(-10 * u) with u uniform, i.e. log-uniform over ten
# decades.  A separate key is used for the flux draw.
key = jax.random.key(4242)
flux_exponent = -10 * jax.random.uniform(key, shape=(n_stars,))
f_stars = 10**flux_exponent

# Render the science image on the first sensor.
# NOTE(review): source_type='parallel' presumably treats each row of
# `stars` as a direction of a source at infinity — confirm.
image_science = telescope.render(
    stars,
    f_stars,
    source_type='parallel',
    sensor_idx=0,
)

# Show the rendered image on the sensor's pixel layout.
fig, ax = plt.subplots(ncols=1, figsize=(12, 12))
ax_hex = hexshow(image_science, telescope.sensors[0], ax=ax)
Simulating a point source at a finite distance:
# Create a random point source at 400 m distance to produce a bokeh effect.
N_points = 1
key = jax.random.key(45)
key1, key2 = jax.random.split(key)

# jax.random.uniform documents `shape` as a tuple of ints, so pass
# (N_points,) rather than a bare int; this also matches how the star-field
# cell calls it.
# NOTE(review): the +/-0.1 lateral offsets are presumably in metres, like
# the 400 m distance — confirm.
x = jax.random.uniform(key1, (N_points,), minval=-0.1, maxval=0.1)
y = jax.random.uniform(key2, (N_points,), minval=-0.1, maxval=0.1)
z = jnp.ones(N_points) * 400

# One (x, y, z) row per point, with unit flux for every point.
points = jnp.stack([x, y, z], axis=1)
f_points = jnp.ones(len(points))

# Render the point-source image on the first sensor.
image_science = telescope.render(points, f_points, source_type='point', sensor_idx=0)

# Visualize
fig, ax = plt.subplots(ncols=1, figsize=(12, 12))
ax_hex = hexshow(image_science, telescope.sensors[0], ax=ax)
Tracing actual rays:
Sometimes you just need to trace actual rays…
%%time
# Launch 100000 rays from a disk of 12 m radius placed at z = 40.
n_rays = 100000
key = jax.random.key(123)
key1, key2 = jax.random.split(key)

# Taking the square root of the uniform draw spreads the origins evenly
# over the disk area instead of clustering them at the centre.
radial_draw = jax.random.uniform(key1, (n_rays,))
r = 12.0 * jnp.sqrt(radial_draw)
theta = jax.random.uniform(key2, (n_rays,)) * 2 * jnp.pi

origin_x = r * jnp.cos(theta)
origin_y = r * jnp.sin(theta)
origin_z = jnp.ones(n_rays) * 40.0
ray_origins = jnp.stack([origin_x, origin_y, origin_z], axis=1)

# Every ray shares one direction: pointing towards -z, tilted by
# 1.5 degrees towards +x.
tilt_angle = 1.5 * jnp.pi / 180
tilted_direction = jnp.array([jnp.sin(tilt_angle), 0.0, -jnp.cos(tilt_angle)])
ray_directions = jnp.broadcast_to(tilted_direction, (n_rays, 3))
ray_values = jnp.ones(n_rays)

# Trace the same rays twice: once binned onto the first sensor, once in
# debug mode to get per-ray output for the spot diagram.
image_rays = telescope.trace(ray_origins, ray_directions, ray_values, sensor_idx=0)
ray_debug = telescope.trace(
    ray_origins,
    ray_directions,
    ray_values,
    sensor_idx=0,
    debug=True,
)

# Left panel: sensor image.  Right panel: log-scaled 2D histogram of the
# first two columns of ray_debug[0], restricted to rays with ray_debug[2] > 0.
# NOTE(review): ray_debug[0] is presumably the per-ray hit position and
# ray_debug[2] the surviving ray value — confirm against telescope.trace.
fig, ax = plt.subplots(ncols=2, figsize=(14, 6))
ax_hex = hexshow(image_rays, telescope.sensors[0], ax=ax[0])
positive = ray_debug[2] > 0
ax_scatter = ax[1].hist2d(
    ray_debug[0][positive, 0],
    ray_debug[0][positive, 1],
    bins=100,
    norm='log',
)
CPU times: user 2.32 s, sys: 411 ms, total: 2.73 s
Wall time: 1.96 s