MST (Medium-Sized Telescope) example

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

from iactrace import MCIntegrator, Telescope, hexshow, show_telescope

Load the MST telescope configuration from YAML:

# Build the telescope from its YAML description. The path is relative to the
# notebook's working directory. MCIntegrator(256) is handed to the telescope as
# its integrator (presumably 256 Monte Carlo samples -- confirm in iactrace),
# and a fixed PRNG key (seed 42) is passed to from_yaml.
config_path = '../../configs/CTAO/MST_North_like.yaml'
integrator = MCIntegrator(256)
telescope = Telescope.from_yaml(config_path, integrator, key=jax.random.key(42))

Visualizing the telescope:

# Build a scene of the telescope with show_telescope and display it using the
# Jupyter viewer.
scene = show_telescope(telescope)
scene.show(viewer="jupyter")

Simulating a star field:

# Generate a random star field: n_stars unit direction vectors scattered
# uniformly over a square patch of sky around the -z axis.
# NOTE(review): seed 42 is also used to build the telescope; pick a distinct
# seed if the two random streams must be independent.
n_stars = 1000
key = jax.random.key(42)
key1, key2 = jax.random.split(key)

# Full width of the square patch: 7.7 degrees, converted to radians.
fov_deg = 7.7
fov_rad = fov_deg * jnp.pi / 180

# Transverse components are drawn from [-fov_rad/2, fov_rad/2] on the plane
# z = -1; for a patch this small they closely track the angular offsets.
x = jax.random.uniform(key1, shape=(n_stars,), minval=-0.5 * fov_rad, maxval=0.5 * fov_rad)
y = jax.random.uniform(key2, shape=(n_stars,), minval=-0.5 * fov_rad, maxval=0.5 * fov_rad)
z = -jnp.ones((n_stars,))

# Assemble (n_stars, 3) vectors and scale each row to unit length.
stars = jnp.column_stack((x, y, z))
stars = stars / jnp.linalg.norm(stars, axis=-1, keepdims=True)

# Fluxes spread log-uniformly over ten decades, between 1e-10 and 1.
key = jax.random.key(4242)
f_stars = jnp.power(10.0, -10.0 * jax.random.uniform(key, shape=(stars.shape[0],)))

# Render the stars as parallel (infinitely distant) sources onto sensor 0.
image_science = telescope.render(stars, f_stars, source_type='parallel', sensor_idx=0)

# Display the resulting camera image on a hexagonal pixel map.
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(12, 12))
ax_hex = hexshow(image_science, telescope.sensors[0], ax=ax)
../_images/87035074f87533445e27d5427d4bf0115be70c3620f8fe4da35866b0a25d6375.png

Simulating a point source at finite distance:

# Place a single point source at a finite distance (z = 100 m) with a small
# random lateral offset. Because the source is not at infinity, its image is
# defocused, which produces the bokeh effect.
N_points = 1

key = jax.random.key(45)
key1, key2 = jax.random.split(key)

# Lateral offsets drawn uniformly from [-0.1, 0.1] (presumably metres, like z).
# `shape` is passed as a tuple, as documented for jax.random.uniform; a bare
# int depends on version-specific coercion.
x = jax.random.uniform(key1, (N_points,), minval=-0.1, maxval=0.1)
y = jax.random.uniform(key2, (N_points,), minval=-0.1, maxval=0.1)
z = jnp.ones((N_points,)) * 100

# Source positions as an (N_points, 3) array, each with unit flux.
points = jnp.stack([x, y, z], axis=1)
f_points = jnp.ones(len(points))

# Render the point source(s) onto sensor 0.
image_science = telescope.render(points, f_points, source_type='point', sensor_idx=0)

# Display the resulting camera image on a hexagonal pixel map.
fig, ax = plt.subplots(ncols=1, figsize=(12, 12))
ax_hex = hexshow(image_science, telescope.sensors[0], ax=ax)
../_images/1ef34f3de6eb635157d75d2b98d6f6d27ffffb25c8b12f934f220cb2c0a33691.png

Tracing actual rays:

Sometimes you need to trace individual rays explicitly, for example to build a spot diagram:

# Build a bundle of n_rays rays whose origins are spread uniformly over a
# disk of radius 10 m in the plane z = 20.
n_rays = 10000
key = jax.random.key(123)
key1, key2 = jax.random.split(key)

# The square root of a uniform variate makes the origins uniform in area,
# rather than clustered toward the centre of the disk.
r = 10.0 * jnp.sqrt(jax.random.uniform(key1, (n_rays,)))
theta = jax.random.uniform(key2, (n_rays,)) * 2 * jnp.pi
ray_origins = jnp.stack([
    r * jnp.cos(theta),
    r * jnp.sin(theta),
    jnp.ones(n_rays) * 20.0
], axis=1)

# All rays share one unit direction, tilted 2.5 degrees from -z towards +x.
tilt_angle = 2.5 * jnp.pi / 180
ray_directions = jnp.broadcast_to(
    jnp.array([jnp.sin(tilt_angle), 0.0, -jnp.cos(tilt_angle)]),
    (n_rays, 3)
)
ray_values = jnp.ones(n_rays)

# Trace once to bin the rays into the sensor image, and once more in debug
# mode to obtain per-ray output for the spot diagram.
image_rays = telescope.trace(ray_origins, ray_directions, ray_values, sensor_idx=0)
ray_debug = telescope.trace(ray_origins, ray_directions, ray_values, sensor_idx=0, debug=True)

# Keep only rays whose ray_debug[2] entry is positive (presumably rays that
# reached the sensor with non-zero weight; confirm against iactrace's debug
# output). The mask is computed once and reused for both coordinates.
hit = ray_debug[2] > 0

# Left: sensor image. Right: log-scaled 2D histogram of the columns 0 and 1
# of ray_debug[0] for the selected rays (spot diagram).
fig, ax = plt.subplots(ncols=2, figsize=(14, 6))
ax_hex = hexshow(image_rays, telescope.sensors[0], ax=ax[0])
ax_scatter = ax[1].hist2d(ray_debug[0][hit, 0], ray_debug[0][hit, 1], bins=50, norm='log')
../_images/f7368c0e7b774cedb79db47a85703bb7f8ca92c268417d8a9ed42d9fba7f9e3f.png